diff --git a/CHANGELOG.md b/CHANGELOG.md index 828c8e03ac..039fc805da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,19 @@ # NVIDIA CUTLASS Changelog +## [3.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.0) (2023-08-03) + +* New warp-specialized persistent FP8 GEMM kernel [kernel schedules](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](/examples/54_hopper_fp8_warp_specialized_gemm). FP8 GEMMs come with a fast accumulation mode. When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results will not periodically be promoted to a higher precision. +* New [Epilogue Visitor Tree (EVT)](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allows for user-defined customized epilogue fusion patterns without having to write a new epilogue. +* [Stream-K](/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release. +* Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp). +* Improved performance for [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA. +* [Hopper GEMM+Permute](/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue. +* New CUTLASS 2D Convolution Python interface. New [example](/examples/python/03_basic_conv2d.ipynb) here. +* Support for Windows (MSVC) builds. Tested with Visual Studio 2019 v16.11.27 on Windows 10.0. +* Optimal performance using [**CUDA 12.2u1**](https://developer.nvidia.com/cuda-downloads) +* Updates and bugfixes from the community (thanks!) + ## [3.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.1.0) (2023-04-14) * New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](/python/README.md) and new [examples](/examples/python). * New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper. diff --git a/CMakeLists.txt b/CMakeLists.txt index 16e0f9ae5e..eba7ecdc2c 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,7 +40,7 @@ endif() message(STATUS "CMake Version: ${CMAKE_VERSION}") set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set") -project(CUTLASS VERSION 3.1.0 LANGUAGES CXX) +project(CUTLASS VERSION 3.2.0 LANGUAGES CXX) include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake) if (CUDA_VERSION VERSION_LESS 11.3) @@ -181,8 +181,8 @@ if(WIN32) endif() if (WIN32) - # Enable more warnings and treat as errors - list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3 -Xcompiler=/WX) + # Enable more warnings. Add "-Xcompiler=/WX" to enable warnings as errors. + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3) # Disable warning on Unicode characters list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/wd4819) @@ -376,6 +376,27 @@ if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.18) cmake_policy(SET CMP0104 NEW) endif() +if (MSVC) + + # MSVC by default does not apply the correct __cplusplus version as specified by the C++ standard + # because MSVC is not a completely compliant implementation. This option forces MSVC to use the + # appropriate value given the requested --std option. This fixes a compilation issue mismatch + # between GCC/Clang and MSVC. + # + # error : a constexpr function cannot have a nonliteral return type "dim3" + # + # See https://developercommunity.visualstudio.com/t/msvc-incorrectly-defines-cplusplus/139261 + + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /Zc:__cplusplus") + +endif() + +# Some tests require this build option in order to link. +if (MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /bigobj") +endif() + function(cutlass_apply_cuda_gencode_flags TARGET) set(options) set(oneValueArgs) @@ -490,7 +511,8 @@ endfunction() # GLOB for CUTLASS header files. Should we use a static list instead? file(GLOB_RECURSE CUTLASS_INCLUDE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} include/cutlass/*.h) -file(GLOB_RECURSE CUTLASS_CUTLASS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include include/cutlass/*.h) +file(GLOB_RECURSE CUTLASS_CUTLASS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include include/cutlass/*.h include/cutlass/*.hpp include/cutlass/*.inl) +file(GLOB_RECURSE CUTLASS_CUTE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include include/cute/*.h*) file(GLOB_RECURSE CUTLASS_NVRTC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/test test/unit/nvrtc/kernel/*.h) ################################################################################################### @@ -647,7 +669,7 @@ endif() ################################################################################ -set(CUTLASS_CTEST_TEMPLATE_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/CTestTestfile.config.cmake) +set(CUTLASS_CTEST_TEMPLATE_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/CTestTestfile.configure.cmake) set(CUTLASS_CTEST_GENERATED_FILES "" CACHE INTERNAL "") function(cutlass_add_executable_tests NAME TARGET) @@ -678,6 +700,9 @@ function(cutlass_add_executable_tests NAME TARGET) set(__DISABLE_TESTS OFF) endif() + set(TEST_EXE $) + set(TEST_EXE_WORKING_DIRECTORY ./${CMAKE_INSTALL_BINDIR}) + if (__RESULT_CACHE_FILE) add_custom_command( @@ -722,6 +747,16 @@ function(cutlass_add_executable_tests NAME TARGET) endforeach() endif() + if (CUTLASS_INSTALL_TESTS) + + set(_INLINE_PER_TEST_CODE) + + file(READ "${PROJECT_SOURCE_DIR}/cmake/CTestTestfile.test.configure.cmake" _INLINE_PER_TEST_CODE_TEMPLATE) + + endif() + + set(TEST_GROUP_NAME ${NAME}) + foreach(CMD_OPTIONS_VAR IN LISTS __TEST_COMMAND_OPTIONS) if (CMD_COUNT GREATER 1) @@ -756,41 +791,47 @@ function(cutlass_add_executable_tests NAME TARGET) add_dependencies(${DEPENDEE} ${TEST_NAME}) endforeach() - add_test( - NAME c${TEST_NAME} - COMMAND ${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $ ${TEST_COMMAND_OPTIONS} - ) + set(TEST_NAME c${TEST_NAME}) + string(CONFIGURE "${_INLINE_PER_TEST_CODE_TEMPLATE}" _TEST_CODE @ONLY) + string(APPEND _INLINE_PER_TEST_CODE "${_TEST_CODE}") - set_tests_properties(c${TEST_NAME} PROPERTIES DISABLED ${__DISABLE_TESTS}) + endforeach() - if (CUTLASS_INSTALL_TESTS) + # To run the tests from an install package with tests enabled, we need to generate test files + # that don't rely on the current directory structure in build. - # To run the tests from an install package with tests enabled, we need to generate test files - # that don't rely on the current directory structure in build. + set(TEST_NAME c${NAME}) + set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/ctest/${TEST_NAME}) + file(MAKE_DIRECTORY ${TEST_GEN_DIR}) - set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/${NAME}) - file(MAKE_DIRECTORY ${TEST_GEN_DIR}) + set(TEST_EXE_PATH $) + set(TEST_USE_EXTENDED_FORMAT ON) + configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" @ONLY) - set(TEST_NAME c${TEST_NAME}) - set(TEST_EXE $) - set(TEST_EXE_WORKING_DIRECTORY ./${CMAKE_INSTALL_BINDIR}) - configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.config.cmake" @ONLY) + set(TEST_EXE_PATH $) + set(TEST_USE_EXTENDED_FORMAT OFF) # ctest does not support extended add_test format. + configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in" @ONLY) - file(GENERATE - OUTPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" - INPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.config.cmake" - ) - - install( - FILES "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" - DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ctest/ - ) - - set(CUTLASS_CTEST_GENERATED_FILES ${CUTLASS_CTEST_GENERATED_FILES};ctest/CTestTestfile.${TEST_NAME}.cmake CACHE INTERNAL "") - - endif() + # The following line imports the tests for immediate run via `make test`. - endforeach() + include(${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake) + + set(CUTLASS_CTEST_GENERATED_FILES ${CUTLASS_CTEST_GENERATED_FILES};ctest/${TEST_NAME}/CTestTestfile.${TEST_NAME}.cmake CACHE INTERNAL "") + + if (CUTLASS_INSTALL_TESTS) + + file(GENERATE + OUTPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake" + INPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in" + ) + + install( + FILES "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake" + DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ctest/${TEST_NAME} + RENAME CTestTestfile.${TEST_NAME}.cmake + ) + + endif() endfunction() @@ -813,33 +854,20 @@ endif() if (CUTLASS_INSTALL_TESTS) - file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/cmake") + file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/ctest") - file(WRITE "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake" "# Generated File\n") + file(WRITE "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "# Generated File\n") foreach(GENERATED_FILE ${CUTLASS_CTEST_GENERATED_FILES}) - file(APPEND "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake" "include(${GENERATED_FILE})\n") + file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "include(${GENERATED_FILE})\n") endforeach() install( - FILES "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake" + FILES "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" DESTINATION "${CUTLASS_TEST_INSTALL_PREFIX}/" ) endif() -#? install( -#? FILES ${CMAKE_BINARY_DIR}/CTestTestfile.cmake -#? DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ -#? ) -#? -#? install( -#? DIRECTORY -#? ${CMAKE_BINARY_DIR}/tools -#? ${CMAKE_BINARY_DIR}/test -#? DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ -#? FILES_MATCHING PATTERN "CTestTestfile.cmake" -#? ) - ################################################################################ include(CMakePackageConfigHelpers) @@ -866,4 +894,3 @@ install( include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassPackageConfig.cmake) - diff --git a/CUDA.cmake b/CUDA.cmake index 4ca903b674..32bd8a58b4 100644 --- a/CUDA.cmake +++ b/CUDA.cmake @@ -228,7 +228,14 @@ else() endif() set(CUTLASS_UNITY_BUILD_ENABLED ${CUTLASS_UNITY_BUILD_ENABLED_INIT} CACHE BOOL "Enable combined source compilation") -set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files") + +if (MSVC) + set(CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT 8) +else() + set(CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT 16) +endif() + +set(CUTLASS_UNITY_BUILD_BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT} CACHE STRING "Batch size for unified source files") function(cutlass_unify_source_files TARGET_ARGS_VAR) diff --git a/README.md b/README.md index e6c0f923e7..7ed86c117f 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ ![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 3.1 +# CUTLASS 3.2 -_CUTLASS 3.1 - April 2023_ +_CUTLASS 3.2 - August 2023_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels @@ -41,33 +41,17 @@ and improves code composability and readability. More documentation specific to In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components. -# What's New in CUTLASS 3.1 - -CUTLASS 3.1 is an update to CUTLASS adding: - -- New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](/python/README.md) and new [examples](/examples/python). -- New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper. -- Support for [fused epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such Bias, ReLU and GELU, using the new efficient epilogues. -- New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA. -- New [*warp-specialized persistent cooperative*](include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that improves performance on Hopper. -- An [example](examples/51_hopper_gett) showcasing GEMM-Like Tensor-Tensor Contraction (GETT) capability on Hopper. -- New Epilogue builders. Similar to mainloop builders (see [example 49](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization. -- Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler. -- Changes to the [GEMM API 3.x](media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs. -- [FMHA Backward Pass](examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers. -- [Streamk GEMM with Broadcast](examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM. -- [Batched B2B GEMM](examples/13_two_tensor_op_fusion) now can run multiple Back-to-Back GEMM with the same problem size in parallel. -- [Batched Strided GEMV](test/unit/gemm/device/gemv.cu) support both row major and column major input matrix. -- [Permute + GEMM fusion](examples/39_gemm_permute) can fuse Permute with following GEMM now. Before, we only support fusing GEMM with Permute in the epilogue. -- [Row Broadcast](include/cutlass/epilogue/threadblock/predicated_tile_iterator_row_broadcast.h) can be fused in the epilogue. - -- *Announcement*: - - The GitHub branch is renamed from `master` to `main` in this release. - - A slight modification has been made to the ordering of arguments passed in to epilogues in 3.x kernels. - Existing CUTLASS 3.x kernel invocations will need to be modified to reflect this change. 2.x kernels - remain unaffected. See [#890](https://github.com/NVIDIA/cutlass/issues/890) for additional information. - - The CUTLASS Python interface supersedes PyCUTLASS. PyCUTLASS has been moved to [/python/cutlass/backend](/python/cutlass/backend). - Backward compatibility between the Python interface and PyCUTLASS will not be maintained moving forward. +# What's New in CUTLASS 3.2 + +CUTLASS 3.2 is an update to CUTLASS adding: +- New warp-specialized persistent FP8 GEMM kernel [kernel schedules](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](/examples/54_hopper_fp8_warp_specialized_gemm). +- New [Epilogue Visitor Tree (EVT)](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allows for user-defined customized epilogue fusion patterns without having to write a new epilogue. +- [Stream-K](/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release. +- Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp). +- Improved performance for [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA. +- [Hopper GEMM+Permute](/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue. +- New CUTLASS 2D Convolution Python interface. New [example](/examples/python/03_basic_conv2d.ipynb) here. +- Support for Windows (MSVC) builds. Minimum requirements: @@ -111,8 +95,8 @@ as shown in the above figure. Tensor Core operations are implemented using CUDA # Compatibility CUTLASS requires a C++17 host compiler and -performs best when built with the [**CUDA 12.1 Toolkit**](https://developer.nvidia.com/cuda-toolkit). -It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, and CUDA 12.0. +performs best when built with the [**CUDA 12.2 Toolkit**](https://developer.nvidia.com/cuda-toolkit). +It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0 and CUDA 12.1. ## Operating Systems We have tested the following environments. @@ -122,8 +106,9 @@ We have tested the following environments. | Ubuntu 18.04 | GCC 7.5.0 | | Ubuntu 20.04 | GCC 10.3.0 | | Ubuntu 22.04 | GCC 11.2.0 | +| Windows 10.0 | Visual Studio 2019 v16.11.27 | -Note: We plan to add Windows (MSVC) & Clang compiler support soon. +Note: We plan to add Clang compiler support soon. Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended. ## Hardware diff --git a/cmake/CTestTestfile.config.cmake b/cmake/CTestTestfile.config.cmake deleted file mode 100644 index 0705b19c12..0000000000 --- a/cmake/CTestTestfile.config.cmake +++ /dev/null @@ -1,21 +0,0 @@ -# Generated file - -if (DEFINED ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT}) - set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT $ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT}) -else() - set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT @CUTLASS_TEST_EXECUTION_ENVIRONMENT@) -endif() - -if (NOT "@TEST_EXE_DIR@" STREQUAL "") - set(TEST_EXE_PATH @TEST_EXE_DIR@/@TEST_EXE@) -else() - set(TEST_EXE_PATH @TEST_EXE@) -endif() - -add_test("@TEST_NAME@" ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@) - -if (NOT "@TEST_EXE_WORKING_DIRECTORY@" STREQUAL "") - set_tests_properties("@TEST_NAME@" PROPERTIES WORKING_DIRECTORY "@TEST_EXE_WORKING_DIRECTORY@") -endif() - -set_tests_properties(@TEST_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@) diff --git a/cmake/CTestTestfile.configure.cmake b/cmake/CTestTestfile.configure.cmake new file mode 100644 index 0000000000..3fc3994647 --- /dev/null +++ b/cmake/CTestTestfile.configure.cmake @@ -0,0 +1,14 @@ +# Generated file + +set(TEST_EXE_PATH @TEST_EXE_PATH@) +set(TEST_EXE_WORKING_DIRECTORY @TEST_EXE_WORKING_DIRECTORY@) +set(CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT @TEST_USE_EXTENDED_FORMAT@) + +if (DEFINED ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT}) + set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT $ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT}) +else() + set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT @CUTLASS_TEST_EXECUTION_ENVIRONMENT@) +endif() + +@_INLINE_PER_TEST_CODE@ + diff --git a/cmake/CTestTestfile.test.configure.cmake b/cmake/CTestTestfile.test.configure.cmake new file mode 100644 index 0000000000..dad2c76cf5 --- /dev/null +++ b/cmake/CTestTestfile.test.configure.cmake @@ -0,0 +1,15 @@ +if (CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT) + # The longform/extended format allows generator expressions to be + # expanded property and is useful in contexts where the files need + # to be immediately included into being-processed cmake code. + add_test(NAME @TEST_NAME@ COMMAND ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@) +else() + add_test(@TEST_NAME@ ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@) +endif() + +if (TEST_EXE_WORKING_DIRECTORY) + set_tests_properties(@TEST_NAME@ PROPERTIES WORKING_DIRECTORY "${TEST_EXE_WORKING_DIRECTORY}") +endif() + +set_tests_properties(@TEST_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@) + diff --git a/examples/10_planar_complex/CMakeLists.txt b/examples/10_planar_complex/CMakeLists.txt index 11ca9724ec..1eb55f1368 100644 --- a/examples/10_planar_complex/CMakeLists.txt +++ b/examples/10_planar_complex/CMakeLists.txt @@ -27,7 +27,10 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - +# +# This example depends on the CUTLASS Library +# +if (CUTLASS_ENABLE_LIBRARY) # Planar Complex GEMM example cutlass_example_add_executable( @@ -35,11 +38,6 @@ cutlass_example_add_executable( planar_complex.cu ) - -# -# This example depends on the CUTLASS Library -# - target_link_libraries( 10_planar_complex PRIVATE @@ -48,3 +46,4 @@ target_link_libraries( cuda ) +endif() diff --git a/examples/11_planar_complex_array/CMakeLists.txt b/examples/11_planar_complex_array/CMakeLists.txt index 64125b5256..aad6d4422c 100644 --- a/examples/11_planar_complex_array/CMakeLists.txt +++ b/examples/11_planar_complex_array/CMakeLists.txt @@ -27,7 +27,10 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - +# +# This example depends on the CUTLASS Library +# +if (CUTLASS_ENABLE_LIBRARY) # Planar Complex Array GEMM example cutlass_example_add_executable( @@ -35,11 +38,6 @@ cutlass_example_add_executable( planar_complex_array.cu ) - -# -# This example depends on the CUTLASS Library -# - target_link_libraries( 11_planar_complex_array PRIVATE @@ -48,3 +46,4 @@ target_link_libraries( cuda ) +endif() diff --git a/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h b/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h index 267423d433..ff38a0d1de 100644 --- a/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h +++ b/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h @@ -447,4 +447,4 @@ struct B2bFusedGroupedGemmRun }; -//////////////////////////////////////////////////////////////////////////////// \ No newline at end of file +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_grouped_f16_sm80_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_grouped_f16_sm80_rf.cu index 4abaee51b2..58f18e023c 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_grouped_f16_sm80_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_grouped_f16_sm80_rf.cu @@ -294,4 +294,4 @@ int main(int argc, char const **args) { -//////////////////////////////////////////////////////////////////////////////// \ No newline at end of file +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h index a8eafaadc9..0e7d498a40 100644 --- a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h +++ b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h @@ -154,4 +154,4 @@ struct B2bGemmGroupedProblemVisitor : public GroupedProblemVisitor< } // namespace gemm } // namespace cutlass -///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_two_tensor_op_fusion/kernel/grouped.h b/examples/13_two_tensor_op_fusion/kernel/grouped.h index 7b6c9504d7..207a711344 100644 --- a/examples/13_two_tensor_op_fusion/kernel/grouped.h +++ b/examples/13_two_tensor_op_fusion/kernel/grouped.h @@ -165,4 +165,4 @@ struct GroupedKernel { } // namespace gemm } // namespace cutlass -///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h b/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h index cb409157cc..42ef4110a8 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h +++ b/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h @@ -150,4 +150,4 @@ struct B2bGemmGroupedThreadblockSwizzle : GroupedThreadblockSwizzle< } // namespace threadblock } // namespace gemm -} // namespace cutlass \ No newline at end of file +} // namespace cutlass diff --git a/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu b/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu index 3ae92c36c1..cbbadeda1d 100644 --- a/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu +++ b/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu @@ -316,7 +316,11 @@ int run(Options &options) { // <- Fill tensor_b_indices on host with unique random integers std::vector to_fill(problem_size.n()) ; // vector with ints. std::iota (std::begin(to_fill), std::end(to_fill), 0); // Fill with 0, 1, ...., problem_size.n() - std::random_shuffle(to_fill.begin(), to_fill.end()); + { // std::random_shuffle was deprecated in C++14 and removed in C++17 + std::random_device make_seed; + std::mt19937 source_of_randomness(make_seed()); + std::shuffle(to_fill.begin(), to_fill.end(), source_of_randomness); + } memcpy(tensor_indices.host_data(), to_fill.data(), options.index_size * sizeof(int)); // Copy data from host to GPU diff --git a/examples/39_gemm_permute/gemm_permute.cu b/examples/39_gemm_permute/gemm_permute.cu index 84e9052c55..ede6b62eee 100644 --- a/examples/39_gemm_permute/gemm_permute.cu +++ b/examples/39_gemm_permute/gemm_permute.cu @@ -283,8 +283,7 @@ struct Options { /////////////////////////////////////////////////////////////////////////////////////////////////// -namespace detail -{ +namespace { // (anonymous) /// Dimension-generic permutation loop template @@ -305,7 +304,7 @@ void permute_host_impl( } } -} // namespace detail +} // namespace (anonymous) /// Perform a reference (host-based) permutation of an input tensor template @@ -332,7 +331,7 @@ void permute_host( cutlass::TensorView view_output(h_output.data(), TensorLayout::packed(shape_perm), shape_perm); decltype(shape_orig) coord; - detail::permute_host_impl<0>(view_input, view_output, Info::permute, coord); + permute_host_impl<0>(view_input, view_output, Info::permute, coord); cutlass::device_memory::copy_to_device(output.data(), h_output.data(), num_elems); } diff --git a/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu b/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu index f71d315acb..db2eff51f3 100644 --- a/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu +++ b/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu @@ -461,6 +461,11 @@ Result run(std::string description, Options &options) std::cout << " GFLOPs: " << result.gflops << std::endl; } + // TODO: uncomment when results match + //if (!result.passed) { + // exit(-1); + //} + return result; } diff --git a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu index 2c7d0ba910..194c400554 100644 --- a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu +++ b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu @@ -293,10 +293,10 @@ bool initialize_block( /// Initialize operands to be used in the GEMM and reference GEMM void initialize(const Options &options) { - stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, Int<1>{})); - stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, Int<1>{})); - stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, Int<1>{})); - stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, Int<1>{})); + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, Int<1>{})); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, Int<1>{})); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, Int<1>{})); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, Int<1>{})); block_A.reset(options.m * options.k); block_B.reset(options.k * options.n); diff --git a/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu b/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu index 6bbdfb6a93..25f637ac49 100644 --- a/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu +++ b/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu @@ -106,6 +106,13 @@ This example also illustrates how CUTLASS 3 GEMMs targeting Hopper automatically support batched GEMMs by simply extending the problem size with an additional tensor rank. + CUTLASS 3.2 provides initial support for epilogue visitor trees (EVT) for the TMA warp-specialized collective. + EVTs allow users to define their own customized epilogue fusion patterns without having to write a new + collective epilogue. This is done by representing the fusion as a compute graph, where each node is one of a + fundamental set of load, store, or compute operations. These operations are either elementwise for tensor + inputs/outputs, broadcasts for vector/scalar inputs, or reductions for vector/scalar outputs. + This example shows how users can define their own custom EVT and use it with the CollectiveBuilder. + Example usage: $ ./examples/49_hopper_with_collective_builder/49_collective_builder \ --m=2048 --n=2048 --k=2048 --l=2 @@ -124,6 +131,7 @@ #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" #include "cutlass/util/command_line.h" #include "cutlass/util/distribution.h" @@ -234,7 +242,7 @@ bool initialize_block( // For example, if `KernelScheduleAuto` is used for the mainloop builder, `EpilogueScheduleAuto` must // be used for the epilogue builder. // -// Furthermore, if an override schedule is selected, both epilgoue and mainloop schedules must +// Furthermore, if an override schedule is selected, both epilogue and mainloop schedules must // be specifically opt into a compatible selection. // // Behavior of the CollectiveBuilder with `Auto` types is subject to change in future releases @@ -245,7 +253,11 @@ template < // Type of epilogue schedule to generate class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto, // Number of pipeline stages to use - class StageCountType = cutlass::gemm::collective::StageCountAuto + class StageCountType = cutlass::gemm::collective::StageCountAuto, + // Type of tile scheduler to use + class TileSchedulerType = cutlass::gemm::PersistentScheduler, + // Do we use custom epilogue visitor tree (EVT) fusion + bool UseCustomEVT = false > struct ExampleRunner { @@ -254,28 +266,62 @@ struct ExampleRunner { using LayoutC = cutlass::layout::ColumnMajor; using LayoutD = cutlass::layout::ColumnMajor; - static constexpr int AlignmentA = 8; - static constexpr int AlignmentB = 8; - static constexpr int AlignmentC = 8; - static constexpr int AlignmentD = 8; + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementScalar = float; + + // 16B alignment lets us use TMA + static constexpr int AlignmentA = 16 / sizeof(ElementA); + static constexpr int AlignmentB = 16 / sizeof(ElementB); + static constexpr int AlignmentC = 16 / sizeof(ElementC); + static constexpr int AlignmentD = 16 / sizeof(ElementD); + + static_assert(not UseCustomEVT || + (cute::is_same_v || + cute::is_same_v), + "Epilogue visitor trees are currently only supported by the TMA warp-specialized epilogue"); + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + + // EVTs can be constructed by composing the fundamental load/store/compute visitor operations defined in include/cutlass/epilogue/fusion + // For more complex examples of EVT construction please refer to include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp + using CustomEVT = // alpha * acc + beta * C + cutlass::epilogue::fusion::Sm90EVT, // beta * C + (alpha * acc) + cutlass::epilogue::fusion::Sm90ScalarBroadcast, // beta + cutlass::epilogue::fusion::Sm90SrcFetch, // C + cutlass::epilogue::fusion::Sm90EVT, // alpha * acc + cutlass::epilogue::fusion::Sm90ScalarBroadcast, // alpha + cutlass::epilogue::fusion::Sm90AccFetch // acc + > + >; + + // A predefined set of fusion operations (implemented with EVT) are supported by the TMA warp-specialized epilogue. + // Users can select one of these operations by passing one of the tags defined in include/cutlass/epilogue/fusion/operations.hpp + // to the CollectiveBuilder. This frees the user from having to compute additional parameters such as stage counts and copy atoms/layouts. + // These tags also provide additional metadata that can be queried at compile time. + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, Shape<_128,_128,_64>, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - cutlass::half_t, LayoutC, AlignmentC, - cutlass::half_t, LayoutD, AlignmentD, - EpilogueScheduleType + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueScheduleType, + cute::conditional_t >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - cutlass::half_t, LayoutA, AlignmentA, - cutlass::half_t, LayoutB, AlignmentB, - float, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, Shape<_128,_128,_64>, Shape<_2,_1,_1>, - std::conditional_t, + cute::conditional_t, cutlass::gemm::collective::StageCountAutoCarveout<(int)sizeof(typename CollectiveEpilogue::SharedStorage)>, StageCountType>, MainloopScheduleType @@ -284,7 +330,8 @@ struct ExampleRunner { using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveMainloop, - CollectiveEpilogue + CollectiveEpilogue, + TileSchedulerType >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; @@ -315,8 +362,8 @@ struct ExampleRunner { cutlass::DeviceAllocation block_A; cutlass::DeviceAllocation block_B; cutlass::DeviceAllocation block_C; - cutlass::DeviceAllocation block_D; - cutlass::DeviceAllocation block_ref_D; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; // // Methods @@ -332,15 +379,15 @@ struct ExampleRunner { cutlass::reference::device::GemmComplex( {M, N, K}, - typename Gemm::EpilogueOutputOp::ElementCompute(alpha), + ElementScalar(alpha), ref_A, cutlass::ComplexTransform::kNone, ref_B, cutlass::ComplexTransform::kNone, - typename Gemm::EpilogueOutputOp::ElementCompute(beta), + ElementScalar(beta), ref_C, ref_D, - typename Gemm::EpilogueOutputOp::ElementAccumulator(0.f), + ElementAccumulator(0), L, // batch_count M * K, // batch_stride_A K * N, // batch_stride_B @@ -366,10 +413,10 @@ struct ExampleRunner { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto [M, N, K, L] = problem_shape_MNKL; - stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); - stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); - stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); - stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); block_A.reset(M * K * L); block_B.reset(K * N * L); @@ -391,10 +438,36 @@ struct ExampleRunner { cutlass::gemm::GemmUniversalMode::kGemm, problem_size, {block_A.get(), stride_A, block_B.get(), stride_B}, - {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + {{}, // epilogue.thread + block_C.get(), stride_C, block_D.get(), stride_D}, hw_info }; + // Custom EVT fusions will have nested unnamed args, the structure of which + // can be deduced from the type definition of the EVT. + // Each node's arguments has the recursive structure of + // {first_child_args, ..., last_child_args, op_args}, + // For more complex examples of EVT initialization please refer to + // include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp + if constexpr (UseCustomEVT) { + arguments.epilogue.thread = + { // ternary op : beta * C + (alpha * acc) + {{options.beta}}, // leaf op+args : beta + {}, // leaf op+args : C + { // binary op : alpha * acc + {{options.alpha}}, // leaf op+args : alpha + {}, // leaf op+args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + // Pre-defined fusions will have flat, named args for user-friendlyness + else { + arguments.epilogue.thread.alpha = options.alpha; + arguments.epilogue.thread.beta = options.beta; + } + Gemm gemm_op; size_t workspace_size = Gemm::get_workspace_size(arguments); @@ -531,19 +604,19 @@ int main(int argc, char const **args) { print_result("Automatically-selected schedule with 5 stages", passed); // One can also override the scheduling policy to use. In this case, use the KernelTma scheduling - // policy, which specifies that the Hopper TMA feature should be used, and we also use an epilgoue + // policy, which specifies that the Hopper TMA feature should be used, and we also use an epilogue // that does not use any shared memory. ExampleRunner tma_schedule_auto_stage_runner; passed = tma_schedule_auto_stage_runner.run(options, hw_info); print_result("TMA schedule with automatically-selected stage count", passed); // Here, we override the scheduling policy to use Hopper's TMA feature alongside the warp-specialized - // scheduling policy, and an epilgoue that does not use any shared memory. + // scheduling policy, and an epilogue that does not use any shared memory. ExampleRunner ws_schedule_auto_stage_runner; passed = ws_schedule_auto_stage_runner.run(options, hw_info); print_result("Warp-specialized TMA schedule with automatically-selected stage count", passed); - // Finally, we override the scheduling policy to use Hopper's TMA feature, alongside the warp-specialized + // Here, we override the scheduling policy to use Hopper's TMA feature, alongside the warp-specialized // scheduling policy, TMA-based epilogue, leveraging persistent thread blocks. ExampleRunner< cutlass::gemm::KernelTmaWarpSpecializedPingpong, @@ -551,6 +624,27 @@ int main(int argc, char const **args) { passed = ws_pingpong_schedule_auto_stage_runner.run(options, hw_info); print_result("Ping-pong warp-specialized TMA schedule with automatically-selected stage count", passed); + // Here, we override the scheduling policy to use stream-K problem decomposition atop the cooperative + // warp-specialized scheduling policy. This kernel continues to leverage persistent thread blocks + // as well aso TMA in both the mainloop and epilogue. + ExampleRunner< + cutlass::gemm::KernelTmaWarpSpecializedCooperative, + cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::StreamKScheduler> ws_cooperative_stream_k_schedule_auto_stage_runner; + passed = ws_cooperative_stream_k_schedule_auto_stage_runner.run(options, hw_info); + print_result("Cooperative warp-specialized TMA schedule using stream-K with automatically-selected stage count", passed); + + // Here, we override the fusion operation to use a customized EVT fusion, in addition to the previous schedule overrides + ExampleRunner< + cutlass::gemm::KernelTmaWarpSpecializedCooperative, + cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::PersistentScheduler, + true> ws_cooperative_schedule_auto_stage_custom_evt_runner; + passed = ws_cooperative_schedule_auto_stage_custom_evt_runner.run(options, hw_info); + print_result("Cooperative warp-specialized TMA schedule using custom epilogue visitor tree with automatically-selected stage count", passed); + #endif return 0; diff --git a/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu b/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu index f1595e5db9..d0e2568f4e 100644 --- a/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu +++ b/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu @@ -262,10 +262,10 @@ struct ExampleRunner { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto [M, N, K, L] = problem_shape_MNKL; - stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); - stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); - stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); - stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); block_A.reset(M * K * L); block_B.reset(K * N * L); diff --git a/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu new file mode 100644 index 0000000000..110c6e44b1 --- /dev/null +++ b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu @@ -0,0 +1,686 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Example of a Hopper gather+GEMM+scatter kernel fusion. + + This example fuses gather before GEMM and scatter after GEMM into the same + GEMM kernel. Gather and scatter operation is controled by an index vector + to select rows or columns from A, B, C or D matrices. + + Gather/scatter operations are always performed along a strided dimension + in order to preserve vectorized loads/stores. Thus the index vector is + applied to rows of row-major matrices and columns of column-major matrices. + + Note that the index vector must contain integers in range [0,X) where + X is one of (M,N,K), depending on selected gather dimension. The problem + shape given to the GEMM kernel must consist of matrix sizes AFTER gather + and BEFORE scatter operations are applied. +*/ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/tensor_view_io.h" + +#include "helper.h" +#include "gather_gemm.hpp" +#include "gather_kernel.cuh" +#include "scatter_epilogue.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +using namespace cute; + +namespace example { + +// Command line options parsing +struct Options { + + bool help = false; + + cutlass::gemm::BatchedGemmCoord problem_size = {2048, 2048, 2048, 1}; + int index_size = 1024; + int mode = 1; // N-mode gather/scatter by default + + float alpha = 1.0f; + float beta = 1.0f; + + bool reference_check = true; + int iterations = 20; + + bool valid() const { + return problem_size.m() > 0 + && problem_size.n() > 0 + && problem_size.k() > 0 + && problem_size.batch() > 0 + && 0 <= mode && mode < 3 + && index_size <= problem_size.at(mode) + && iterations > 0; + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("m", problem_size.m()); + cmd.get_cmd_line_argument("n", problem_size.n()); + cmd.get_cmd_line_argument("k", problem_size.k()); + cmd.get_cmd_line_argument("batch_size", problem_size.batch()); + cmd.get_cmd_line_argument("index_size", index_size); + + char const modes[] = {'m', 'n', 'k'}; + char mode_input = modes[mode]; + cmd.get_cmd_line_argument("mode", mode_input); + mode = int(std::distance(std::begin(modes), std::find(std::begin(modes), std::end(modes), mode_input))); + + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + + cmd.get_cmd_line_argument("check", reference_check, true); + cmd.get_cmd_line_argument("iterations", iterations); + + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << + "52_hopper_gather_scatter_fusion example\n" + "\n" + " This example uses the CUTLASS Library to fuse gather/scatter of input/output tensors with GEMM.\n" + " It validates and benchmarks the fused kernel against an unfused implementation that executes\n" + " gather+GEMM+scatter in sequence and writes intermediate (gathered) tensors to memory.\n" + " For the unfused implementation two GEMM kernels are considered: default one that uses the same\n" + " schedule and instruction set as the fused one, and an optimized one that utilizes advanced\n" + " features (such as TMA units) that cannot be used by the fused kernel due to hardware constraints." + "\n" + "Options:\n" + " --help If specified, displays this usage statement.\n" + " --m= GEMM M dimension\n" + " --n= GEMM N dimension\n" + " --k= GEMM K dimension\n" + " --batch_size= GEMM batch size\n" + " --index_size= Size of N dimension gather/scatter index\n" + " --mode= Gather mode (M, N, or K)\n" + " --alpha= GEMM alpha parameter\n" + " --beta= GEMM beta parameter\n" + " --iterations= Number of profiling iterations to perform.\n" + "\n" + "Examples:\n" + "\n" + "$ ./examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion --m=1024 --n=2048 --k=1024 --mode=n --index_size=1024\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ExampleRunner +{ + // Useful aliases + + // Alias to for the epilogue type that supports gather/scatter + using Epilogue = cutlass::epilogue::collective::EpilogueGatherScatter< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination< + ElementD, 1, + ElementAccumulator, ElementComputeEpilogue, + cutlass::epilogue::thread::ScaleType::Default, + cutlass::FloatRoundStyle::round_to_nearest, ElementC + >, + cutlass::gemm::EpilogueDefault, + GatherC, + ScatterD + >; + + // Alias to for the mainloop type + using Mainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 128 / cutlass::sizeof_bits::value, + ElementB, LayoutB, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + Shape<_128,_128,_64>, + Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCount<5>, + cutlass::gemm::KernelMultistage + >::CollectiveOp; + + using ProblemShape = Shape; + + using Kernel = cutlass::gemm::kernel::GemmGather< + ProblemShape, + Mainloop, + Epilogue, + GatherA, + GatherB + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using StrideA = typename Kernel::StrideA; + using StrideB = typename Kernel::StrideB; + using StrideC = typename Kernel::StrideC; + using StrideD = typename Kernel::StrideD; + + static constexpr bool DoGatherA = not cutlass::platform::is_same::value; + static constexpr bool DoGatherB = not cutlass::platform::is_same::value; + static constexpr bool DoGatherC = not cutlass::platform::is_same::value; + static constexpr bool DoScatterD = not cutlass::platform::is_same::value; + + static constexpr bool GatherAonM = DoGatherA && cutlass::platform::is_same::value; + static constexpr bool GatherAonK = DoGatherA && cutlass::platform::is_same::value; + static constexpr bool GatherBonN = DoGatherB && cutlass::platform::is_same::value; + static constexpr bool GatherBonK = DoGatherB && cutlass::platform::is_same::value; + static constexpr bool GatherConM = DoGatherC && cutlass::platform::is_same::value; + static constexpr bool GatherConN = DoGatherC && cutlass::platform::is_same::value; + static constexpr bool ScatterDonM = DoScatterD && cutlass::platform::is_same::value; + static constexpr bool ScatterDonN = DoScatterD && cutlass::platform::is_same::value; + + static constexpr bool GatherModeM = GatherAonM || GatherConM || ScatterDonM; + static constexpr bool GatherModeN = GatherBonN || GatherConN || ScatterDonN; + static constexpr bool GatherModeK = GatherAonK || GatherBonK; + + static_assert( GatherModeM && !GatherModeN && !GatherModeK || + !GatherModeM && GatherModeN && !GatherModeK || + !GatherModeM && !GatherModeN && GatherModeK, + "Only one gather mode (M, N or K) is supported by example runner"); + + // Construct a reference (non-gather) GEMM kernel type + + using MainloopRef = Mainloop; + + using EpilogueRef = typename cutlass::epilogue::collective::DefaultEpilogue< + StrideC, StrideD, + typename Epilogue::ThreadEpilogueOp, + typename Epilogue::EpilogueSchedule + >; + + using KernelRef = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + MainloopRef, + EpilogueRef + >; + + using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + + // Construct an optimized reference GEMM kernel type (using TMA) + + using EpilogueOpt = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, + Shape<_2,_2,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementComputeEpilogue, + ElementC, LayoutC, 128 / cutlass::sizeof_bits::value, + ElementD, LayoutD, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using MainloopOpt = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 128 / cutlass::sizeof_bits::value, + ElementB, LayoutB, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + Shape<_128,_128,_64>, + Shape<_2,_2,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using KernelOpt = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + MainloopOpt, + EpilogueOpt + >; + + using GemmOpt = cutlass::gemm::device::GemmUniversalAdapter; + + // Data members + + cutlass::gemm::BatchedGemmCoord problem_size_orig; + cutlass::gemm::BatchedGemmCoord problem_size; + ProblemShape problem_shape_orig; + ProblemShape problem_shape; + cutlass::KernelHardwareInfo hw_info; + + ElementComputeEpilogue alpha; + ElementComputeEpilogue beta; + + StrideA stride_A_orig; + StrideB stride_B_orig; + StrideC stride_C_orig; + StrideD stride_D_orig; + + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + + cutlass::device_memory::allocation tensor_a; + cutlass::device_memory::allocation tensor_b; + cutlass::device_memory::allocation tensor_c; + cutlass::device_memory::allocation tensor_d; + + cutlass::device_memory::allocation gather_indices; + + cutlass::device_memory::allocation tensor_a_gathered; + cutlass::device_memory::allocation tensor_b_gathered; + cutlass::device_memory::allocation tensor_c_gathered; + cutlass::device_memory::allocation tensor_d_gathered; + cutlass::device_memory::allocation tensor_d_reference; + + cutlass::gemm::GemmUniversalMode gemm_mode; + + Gemm gemm; + typename Gemm::Arguments arguments; + cutlass::device_memory::allocation workspace; + + GemmRef gemm_ref; + typename GemmRef::Arguments arguments_ref; + cutlass::device_memory::allocation workspace_ref; + + GemmOpt gemm_opt; + typename GemmOpt::Arguments arguments_opt; + cutlass::device_memory::allocation workspace_opt; + + ExampleRunner(Options const &options, cutlass::KernelHardwareInfo const &hw_info) + : problem_size_orig(options.problem_size), + problem_size(GatherModeM ? options.index_size : problem_size_orig.m(), + GatherModeN ? options.index_size : problem_size_orig.n(), + GatherModeK ? options.index_size : problem_size_orig.k(), + problem_size_orig.batch()), + problem_shape_orig(problem_size_orig.m(), problem_size_orig.n(), problem_size_orig.k(), problem_size_orig.batch()), + problem_shape(problem_size.m(), problem_size.n(), problem_size.k(), problem_size.batch()), + hw_info(hw_info), + alpha(options.alpha), + beta(options.beta), + stride_A_orig(cutlass::make_cute_packed_stride( + StrideA{}, make_shape(problem_size_orig.m(), problem_size_orig.k(), problem_size_orig.batch()))), + stride_B_orig(cutlass::make_cute_packed_stride( + StrideB{}, make_shape(problem_size_orig.n(), problem_size_orig.k(), problem_size_orig.batch()))), + stride_C_orig(cutlass::make_cute_packed_stride( + StrideC{}, make_shape(problem_size_orig.m(), problem_size_orig.n(), problem_size_orig.batch()))), + stride_D_orig(cutlass::make_cute_packed_stride( + StrideD{}, make_shape(problem_size_orig.m(), problem_size_orig.n(), problem_size_orig.batch()))), + stride_A(cutlass::make_cute_packed_stride( + StrideA{}, make_shape(problem_size.m(), problem_size.k(), problem_size.batch()))), + stride_B(cutlass::make_cute_packed_stride( + StrideB{}, make_shape(problem_size.n(), problem_size.k(), problem_size.batch()))), + stride_C(cutlass::make_cute_packed_stride( + StrideC{}, make_shape(problem_size.m(), problem_size.n(), problem_size.batch()))), + stride_D(cutlass::make_cute_packed_stride( + StrideD{}, make_shape(problem_size.m(), problem_size.n(), problem_size.batch()))), + tensor_a(problem_size_orig.m() * problem_size_orig.k() * problem_size_orig.batch()), + tensor_b(problem_size_orig.k() * problem_size_orig.n() * problem_size_orig.batch()), + tensor_c(problem_size_orig.m() * problem_size_orig.n() * problem_size_orig.batch()), + tensor_d(problem_size_orig.m() * problem_size_orig.n() * problem_size_orig.batch()), + gather_indices(options.index_size), + tensor_a_gathered(problem_size.m() * problem_size.k() * problem_size_orig.batch()), + tensor_b_gathered(problem_size.k() * problem_size.n() * problem_size_orig.batch()), + tensor_c_gathered(problem_size.m() * problem_size.n() * problem_size_orig.batch()), + tensor_d_gathered(problem_size.m() * problem_size.n() * problem_size_orig.batch()), + tensor_d_reference(problem_size_orig.m() * problem_size_orig.n() * problem_size_orig.batch()), + gemm_mode(problem_size.batch() > 1 ? cutlass::gemm::GemmUniversalMode::kBatched : cutlass::gemm::GemmUniversalMode::kGemm), + gemm(), + // When constructing arguments for gather/scatter gemm, we must pass stride arguments + // made for the original (non-gathered) problem size, because they are used to access + // tensors of the original shape. However we still use the reduced (gathered) problem + // shape since it corresponds to the logical indexing in reduced size GEMM. + arguments{ + gemm_mode, + problem_shape, + { + tensor_a.get(), + stride_A_orig, + tensor_b.get(), + stride_B_orig + }, + { + { alpha, beta }, + tensor_c.get(), stride_C_orig, + tensor_d.get(), stride_D_orig, + typename Epilogue::GatherC {gather_indices.get()}, + typename Epilogue::ScatterD{gather_indices.get()} + }, + hw_info, + typename Kernel::GatherA{gather_indices.get()}, + typename Kernel::GatherB{gather_indices.get()} + }, + workspace(Gemm::get_workspace_size(arguments)), + gemm_ref(), + arguments_ref{ + gemm_mode, + problem_shape, + { + DoGatherA ? tensor_a_gathered.get() : tensor_a.get(), + stride_A, + DoGatherB ? tensor_b_gathered.get() : tensor_b.get(), + stride_B + }, + { + { alpha, beta }, + DoGatherC ? tensor_c_gathered.get() : tensor_c.get(), + stride_C, + DoScatterD ? tensor_d_gathered.get() : tensor_d_reference.get(), + stride_D + }, + hw_info + }, + workspace_ref(GemmRef::get_workspace_size(arguments_ref)), + gemm_opt(), + arguments_opt{ + gemm_mode, + problem_shape, + { + DoGatherA ? tensor_a_gathered.get() : tensor_a.get(), + stride_A, + DoGatherB ? tensor_b_gathered.get() : tensor_b.get(), + stride_B + }, + { + { alpha, beta }, + DoGatherC ? tensor_c_gathered.get() : tensor_c.get(), + stride_C, + DoScatterD ? tensor_d_gathered.get() : tensor_d_reference.get(), + stride_D + }, + hw_info + }, + workspace_opt(GemmOpt::get_workspace_size(arguments_opt)) + { + // Fill input and output matrices on host using CUTLASS helper functions + cutlass::reference::device::BlockFillRandomUniform(tensor_a.get(), tensor_a.size(), 1, ElementA(7), ElementA(-8), 0); + cutlass::reference::device::BlockFillRandomUniform(tensor_b.get(), tensor_b.size(), 1, ElementB(7), ElementB(-8), 0); + cutlass::reference::device::BlockFillRandomUniform(tensor_c.get(), tensor_c.size(), 1, ElementC(7), ElementC(-8), 0); + cutlass::reference::device::BlockFillSequential(tensor_d.get(), tensor_d.size(), ElementD(0), ElementD(0)); + + // <- Fill gather_indices with unique random integers in range [0,n) + int index_range = GatherModeM ? problem_size_orig.m() : (GatherModeN ? problem_size_orig.n() : problem_size_orig.k()); + std::vector indices(index_range); + std::iota(indices.begin(), indices.end(), 0); + { // std::random_shuffle was deprecated in C++14 and removed in C++17 + std::random_device make_seed; + std::mt19937 source_of_randomness(make_seed()); + std::shuffle(indices.begin(), indices.end(), source_of_randomness); + } + gather_indices.copy_from_host(indices.data()); + + auto const gemm_init = [](auto & gemm, auto const & arguments, auto & workspace) + { + cutlass::Status status = gemm.can_implement(arguments); + CUTLASS_CHECK(status); + status = gemm.initialize(arguments, workspace.get()); + CUTLASS_CHECK(status); + }; + + gemm_init(gemm, arguments, workspace ); + gemm_init(gemm_ref, arguments_ref, workspace_ref); + gemm_init(gemm_opt, arguments_opt, workspace_opt); + } + + void debug_output(std::ostream & os) + { + auto print_tensor = [](std::ostream &os, char const * name, auto const & data, auto shape, auto stride) + { + std::vector> h_data(data.size()); + data.copy_to_host(h_data.data()); + Tensor t = make_tensor(h_data.data(), shape, stride); + os << "\n" << name << ": " << std::setw(4) << t << std::endl; + }; + { + auto [M,N,K,L] = problem_shape_orig; + print_tensor(os, "A", tensor_a, make_shape(M,K,L), stride_A_orig); + print_tensor(os, "B", tensor_b, make_shape(N,K,L), stride_B_orig); + print_tensor(os, "C", tensor_c, make_shape(M,N,L), stride_C_orig); + print_tensor(os, "D", tensor_d, make_shape(M,N,L), stride_D_orig); + print_tensor(os, "D reference", tensor_d_reference, make_shape(M,N,L), stride_D_orig); + print_tensor(os, "indices", gather_indices, make_shape(gather_indices.size()), make_stride(_1{})); + } + } + + template + static void run_gemm(Gemm2 &gemm) + { + cutlass::Status status = gemm.run(); + CUTLASS_CHECK(status); + } + + template + void run_reference(Gemm2 &gemm) + { + // Convenience wrapper around calls to separate gather/scatter kernels + auto run_gather = [this](auto call, auto const & input, auto & output, auto gather_func, auto batch_size, auto stride) + { + [[maybe_unused]] auto idx = find_if(stride, [](auto x){ return not is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + call(input.get(), + output.get(), + gather_func, + batch_size, + static_cast(input.size() / batch_size), + static_cast(output.size() / batch_size), + static_cast(get(stride)), + hw_info); + }; + + // Forward calls via lambda to avoid specifying template arguments + auto gather_call = [](auto&&... args){ gather(static_cast(args)...); }; + auto scatter_call = [](auto&&... args){ scatter(static_cast(args)...); }; + + if constexpr (DoGatherA) { + run_gather(gather_call, tensor_a, tensor_a_gathered, arguments.gather_A, problem_size.batch(), stride_A); + } + if constexpr (DoGatherB) { + run_gather(gather_call, tensor_b, tensor_b_gathered, arguments.gather_B, problem_size.batch(), stride_B); + } + if constexpr (DoGatherC) { + if (beta != ElementComputeEpilogue(0)) { + run_gather(gather_call, tensor_c, tensor_c_gathered, arguments.epilogue.gather_C, problem_size.batch(), stride_C); + } + } + + run_gemm(gemm); + + if constexpr (DoScatterD) { + run_gather(scatter_call, tensor_d_gathered, tensor_d_reference, arguments.epilogue.scatter_D, problem_size.batch(), stride_D); + } + } + + bool verify() + { + run_gemm(gemm); + run_reference(gemm_ref); + cudaDeviceSynchronize(); + return cutlass::reference::device::BlockCompareEqual(tensor_d.get(), tensor_d_reference.get(), tensor_d.size()); + } + + bool run(Options const &options) + { + if (options.reference_check) { + if (!verify()) { + std::cout << "Failed validation" << std::endl; +#if 1 + debug_output(std::cout); +#endif + return false; + } + else { + std::cout << "Passed validation" << std::endl; + } + } + + // + // Run profiling loop + // + + auto const benchmark = [&](auto name, auto func) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + func(); + } + timer.stop(); + + double runtime = timer.elapsed_millis() / double(options.iterations); + double gflops = 2 * double(problem_size.product()) / 1e6 / runtime; // Two flops per multiply-add + + std::cout << name << ":\n"; + std::cout << " Runtime: " << runtime << " ms\n"; + std::cout << " GFLOPs: " << gflops << "\n"; + }; + + benchmark("Fused", [&](){ run_gemm(gemm); }); + benchmark("Unfused default", [&](){ run_reference(gemm_ref); }); + benchmark("Unfused optimized", [&](){ run_reference(gemm_opt); }); + + return true; + } +}; + +} // namespace example + +int main(int argc, const char ** argv) { + + bool notSupported = false; + + // CUDA 12 minimum required + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA Toolkit version 12 or later.\n"; + notSupported = true; + } + + cudaDeviceProp props; + CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); + + if (props.major < 9) { + std::cerr << "This example requires a device with compute capability 90 or higher.\n"; + notSupported = true; + } + + if (notSupported) { + return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems + } + + example::Options options; + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << "\n"; + return EXIT_SUCCESS; + } + + if (!options.valid()) { + std::cerr << "Invalid arguments." << "\n"; + return EXIT_FAILURE; + } + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool result = true; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + switch (options.mode) { + using namespace example; + case 0: { + std::cout << "Gather A,C + scatter D on M mode:" << std::endl; + using Runner = ExampleRunner< + cutlass::half_t, cutlass::layout::RowMajor, IndexedGather, // A + cutlass::half_t, cutlass::layout::ColumnMajor, NoGather, // B + cutlass::half_t, cutlass::layout::RowMajor, IndexedGather, // C + cutlass::half_t, cutlass::layout::RowMajor, IndexedGather, // D + float, float>; + result &= Runner(options, hw_info).run(options); + break; + } + case 1: { + std::cout << "Gather B,C + scatter D on N mode:" << std::endl; + using Runner = ExampleRunner< + cutlass::half_t, cutlass::layout::RowMajor, NoGather, // A + cutlass::half_t, cutlass::layout::ColumnMajor, IndexedGather, // B + cutlass::half_t, cutlass::layout::ColumnMajor, IndexedGather, // C + cutlass::half_t, cutlass::layout::ColumnMajor, IndexedGather, // D + float, float>; + result &= Runner(options, hw_info).run(options); + break; + } + case 2: { + std::cout << "Gather A,B on K mode:" << std::endl; + using Runner = ExampleRunner< + cutlass::half_t, cutlass::layout::ColumnMajor, IndexedGather, // A + cutlass::half_t, cutlass::layout::RowMajor, IndexedGather, // B + cutlass::half_t, cutlass::layout::RowMajor, NoGather, // C + cutlass::half_t, cutlass::layout::RowMajor, NoGather, // D + float, float>; + result &= Runner(options, hw_info).run(options); + break; + } + } +#endif + + return result ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/examples/52_hopper_gather_scatter_fusion/CMakeLists.txt b/examples/52_hopper_gather_scatter_fusion/CMakeLists.txt new file mode 100644 index 0000000000..ec5ed89fd6 --- /dev/null +++ b/examples/52_hopper_gather_scatter_fusion/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 52_hopper_gather_scatter_fusion + 52_hopper_gather_scatter_fusion.cu + ) diff --git a/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp b/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp new file mode 100644 index 0000000000..458cb19554 --- /dev/null +++ b/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp @@ -0,0 +1,266 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/tensor.hpp" + +#include "gather_tensor.hpp" + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class GatherA_, + class GatherB_, + class TileScheduler_ = void +> +class GemmGather +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + using TileScheduleTag = TileScheduler_; + using TileScheduler = TileScheduler_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + static_assert(std::is_same_v, + "Mainloop and epilogue do not agree on accumulator value type."); + + using GatherA = GatherA_; + using GatherB = GatherB_; + + static constexpr int SharedStorageSize = static_cast(cute::max( + sizeof(typename CollectiveMainloop::SharedStorage), + sizeof(typename CollectiveEpilogue::SharedStorage))); + + static constexpr uint32_t MaxThreadsPerBlock = cute::size(TiledMma{}); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + GatherA gather_A{}; + GatherB gather_B{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode; + ProblemShape problem_shape; + MainloopParams mainloop; + EpilogueParams epilogue; + GatherA gather_A{}; + GatherB gather_B{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + return { + args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + args.gather_A, + args.gather_B + }; + } + + static + Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return Status::kSuccess; + } + + static + bool + can_implement(Arguments const& args) { + return args.mode == GemmUniversalMode::kGemm or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + } + + static + int + get_workspace_size(Arguments const& args) { + return 0; + } + + static constexpr + dim3 + get_grid_shape(Params const& params) { + int batch_count = 1; + if constexpr (rank(ProblemShape{}) == 4) { + batch_count = cute::size<3>(params.problem_shape); + } + + return dim3( + cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))), + cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))), + batch_count + ); + } + + static constexpr + dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // Preconditions + static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + int thread_idx = int(threadIdx.x); + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + auto [m_coord, n_coord, l_coord] = blockIdx; + auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord); // (m,n,k,l) + + // Represent the full tensors + Tensor mA_mkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA, params.gather_A); //(m,k,l) + Tensor mB_nkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB, params.gather_B); //(n,k,l) + + // Get batch slice + Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k) + Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k) + + // Slice to get the tiles this thread block is responsible for + Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + + // Compute tile residues for predication + auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord_mnkl); // M - BLK_M * m_coord + auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord_mnkl); // N - BLK_N * n_coord + auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); + + // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape + TiledMma tiled_mma; + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + clear(accumulators); + + auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); + int k_tile_count = size<2>(gA); + + // Perform the collective scoped MMA + CollectiveMainloop collective_mma; + collective_mma( + accumulators, + gA, + gB, + accumulators, + k_tile_iter, k_tile_count, + residue_mnk, + thread_idx, + smem_buf + ); + + // Epilogue and write to gD + CollectiveEpilogue epilogue{params.epilogue}; + epilogue( + problem_shape_MNKL, + blk_shape, + blk_coord_mnkl, + accumulators, + tiled_mma, + residue_mnk, + thread_idx, + smem_buf + ); + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/examples/52_hopper_gather_scatter_fusion/gather_kernel.cuh b/examples/52_hopper_gather_scatter_fusion/gather_kernel.cuh new file mode 100644 index 0000000000..8a044bb794 --- /dev/null +++ b/examples/52_hopper_gather_scatter_fusion/gather_kernel.cuh @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/numeric/math.hpp" + +namespace example +{ + +// Naive grid-stride loop implementation of gather +template +__global__ void +gather_kernel(Element const * __restrict__ input, + Element * __restrict__ output, + Func func, + int num_elems_input, + int num_elems_output, + cutlass::FastDivmod stride_divmod) +{ + Element const * input_b = input + blockIdx.z * num_elems_input; + Element * output_b = output + blockIdx.z * num_elems_output; + int tidx = threadIdx.x + blockIdx.x * blockDim.x; + for (int k = tidx; k < num_elems_output; k += blockDim.x * gridDim.x) { + int i,j; + stride_divmod(j, i, k); + output_b[k] = input_b[i + func(j) * stride_divmod.divisor]; + } +} + +// Gather elements along strided dimension of the tensor according to given indices +template +void +gather(Element const * input, + Element * output, + Func func, + int batch_size, + int num_elems_input, + int num_elems_output, + int stride, + cutlass::KernelHardwareInfo const& hw_info) +{ + // Upcast to uint128_t data type + int factor = 128 / cutlass::sizeof_bits::value; + assert(stride % factor == 0); + int stride_upcast = stride/factor; + int num_elems_input_upcast = num_elems_input / factor; + int num_elems_output_upcast = num_elems_output / factor; + + cutlass::FastDivmod stride_divmod(stride_upcast); + dim3 blocks(hw_info.sm_count, 1, batch_size); + gather_kernel<<>>(reinterpret_cast(input), + reinterpret_cast(output), + func, + num_elems_input_upcast, + num_elems_output_upcast, + stride_divmod); +} + +// Naive grid-stride loop implementation of scatter +template +__global__ void +scatter_kernel(Element const * __restrict__ input, + Element * __restrict__ output, + Func func, + int num_elems_input, + int num_elems_output, + cutlass::FastDivmod stride_divmod) +{ + Element const * input_b = input + blockIdx.z * num_elems_input; + Element * output_b = output + blockIdx.z * num_elems_output; + int tidx = threadIdx.x + blockIdx.x * blockDim.x; + for (int k = tidx; k < num_elems_input; k += blockDim.x * gridDim.x) { + int i,j; + stride_divmod(j, i, k); + output_b[i + func(j) * stride_divmod.divisor] = input_b[k]; + } +} + +// Gather elements along strided dimension of the tensor according to given indices +template +void +scatter(Element const * input, + Element * output, + Func func, + int batch_size, + int num_elems_input, + int num_elems_output, + int stride, + cutlass::KernelHardwareInfo const& hw_info) +{ + // Upcast to uint128_t data type + int factor = 128 / cutlass::sizeof_bits::value; + assert(stride % factor == 0); + int stride_upcast = stride/factor; + int num_elems_input_upcast = num_elems_input / factor; + int num_elems_output_upcast = num_elems_output / factor; + + cutlass::FastDivmod stride_divmod(stride_upcast); + dim3 blocks(hw_info.sm_count, 1, batch_size); + scatter_kernel<<>>(reinterpret_cast(input), + reinterpret_cast(output), + func, + num_elems_input_upcast, + num_elems_output_upcast, + stride_divmod); +} + +} // namespace example diff --git a/examples/52_hopper_gather_scatter_fusion/gather_tensor.hpp b/examples/52_hopper_gather_scatter_fusion/gather_tensor.hpp new file mode 100644 index 0000000000..9caf0aa677 --- /dev/null +++ b/examples/52_hopper_gather_scatter_fusion/gather_tensor.hpp @@ -0,0 +1,209 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" + +namespace example { + +using namespace cute; + +// Empty type used to disable gather/scatter for a GEMM argument +struct NoGather +{ + template + NoGather(Ts...) {}; +}; + +/// Function object that applies an index to its argument +template +struct IndexedGather +{ + CUTE_HOST_DEVICE constexpr + IndexedGather(Index const *indices = {}): indices_(indices) {} + + template + CUTE_HOST_DEVICE constexpr + Index + operator()(I i) const { return indices_[i]; } + + CUTE_HOST_DEVICE friend + void + print(IndexedGather const &s) { + print("Indexed"); + } + + Index const *indices_; +}; + +/// Function object that applies a stride to its argument +/// Example: StridedFunc gathers every other row/column +template +struct StridedGather +{ + CUTE_HOST_DEVICE constexpr + StridedGather(Stride stride = {}): stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr + auto + operator()(I i) const { return i * stride_; } + + CUTE_HOST_DEVICE friend + void + print(StridedGather const &s) { + print("Strided{"); + print(s.stride_); + print("}"); + } + + Stride stride_; +}; + +/// Custom stride object that applies a function followed by a stride +template +struct CustomStride +{ + CUTE_HOST_DEVICE constexpr + CustomStride(Func const &func, Stride const &stride): func_(func), stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr friend + auto + operator*(I i, CustomStride const &s) { return s.func_(i) * s.stride_; } + + CUTE_HOST_DEVICE friend + void + print(CustomStride const & s) { + print("Custom{"); + print(s.func_); + print(","); + print(s.stride_); + print("}"); + } + + template + CUTE_HOST_DEVICE constexpr friend + auto + safe_div(CustomStride const &s, Div const &div) + { + return CustomStride(s.func_, safe_div(s.stride_, div)); + } + + // Circumvent the requirement on make_layout that shape and stride are integral + template + CUTE_HOST_DEVICE constexpr friend + auto + make_layout(Shape const &shape, CustomStride const &stride) + { + return Layout(shape, stride); + } + + Func func_; + Stride stride_; +}; + +template +CUTLASS_HOST_DEVICE +auto +make_custom_stride_layout(Stride const &stride, Func&& func) +{ + // Use a dummy shape and replace the first non-unit stride with a custom gather stride + auto idx = find_if(stride, [](auto x){ return not is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + return make_layout(repeat_like(stride, _1{}), + replace(stride, CustomStride{static_cast(func), get(stride)})); +} + +/// Helper function to optionally create a gather tensor +template +CUTLASS_HOST_DEVICE +auto +make_gather_tensor(Iterator iter, Shape const &shape, Stride const &stride, Func &&func) +{ + if constexpr (not cutlass::platform::is_same, NoGather>::value) { + Layout matrix_layout = make_identity_layout(shape); + auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); + Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); + return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); + } else { + return make_tensor(iter, shape, stride); + } +} + +} // namespace example + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s,d); }); + } else if constexpr (is_scaled_basis::value) { + if constexpr (Stride::mode() == I) { + return make_layout(shape_div(shape, Int{}), shape_div(stride, Int{})); + } else { + return make_layout(shape, stride); + } + } else { + return upcast(shape, stride); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout,Offset,Layout> const& layout) +{ + // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset + auto idx = find_if(layout.layout_a().stride(), [](auto x){ return is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + + // Upcast the outer layout (works as expected) + auto outer = upcast(layout.layout_a()); + + // Upcast the accumulated offset along stride-1 mode + auto offset = as_arithmetic_tuple(replace(layout.offset(), upcast(get(layout.offset())))); + + // Upcast the inner layout's shape along stride-1 mode + auto inner = upcast(layout.layout_b().shape(), layout.layout_b().stride()); + + return composition(outer, offset, inner); +} + +} // namespace example diff --git a/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp b/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp new file mode 100644 index 0000000000..f08c107c8c --- /dev/null +++ b/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/int.hpp" + +#include "gather_tensor.hpp" + +namespace cutlass::epilogue::collective { + +/// Applies an element wise operation to all elements within the fragment +/// and scatter-writes them out to destination storage. +/// GatherC and ScatterD are types of user-defined functions that apply the +/// transoformation of the strided coordinate (e.g. through an index array). +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class EpilogueSchedule_, + class GatherC_, + class ScatterD_ +> +class EpilogueGatherScatter { +public: + // + // Type Aliases + // + using EpilogueSchedule = EpilogueSchedule_; + + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + + // Every epilogue needs these two GmemTiledCopy{C,D} aliases. + // If you don't know what they should be, just use void. + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + using GatherC = GatherC_; + using ScatterD = ScatterD_; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage { }; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread_params{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + GatherC gather_C{}; + ScatterD scatter_D{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + EpilogueGatherScatter(Params const& params_) : params(params_) { } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + (void) smem_buf; + ThreadEpilogueOp epilogue_op{params.thread_params}; + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + auto stride_c = detail::get_epilogue_stride(params.dC); + auto stride_d = detail::get_epilogue_stride(params.dD); + + // Represent the full output tensor + Tensor mC_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c, params.gather_C); // (m,n,l) + Tensor mD_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d, params.scatter_D); // (m,n,l) + + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + + // Partition source and destination tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, "Accumulator layout must be static"); + CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator count must have the same destination element count."); + + // Make an identity coordinate tensor for predicating our output MN tile + auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor tCcD = thr_mma.partition_C(cD); + + // source is needed + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); + } + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i)); + } + } + } + } + +private: + Params params; +}; + +} // namespace cutlass::epilogue::collective + diff --git a/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu b/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu new file mode 100644 index 0000000000..8800615f8d --- /dev/null +++ b/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu @@ -0,0 +1,975 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper GEMM+permute example. + + This example demonstrates the fusion of tensor permutation operations with a Hopper GEMM kernel. + It is similar in spirit to example 39_gemm_permute, but uses CUTLASS 3 CollectiveBuilder API to + construct kernels that make use of Hopper architecture features: Tensor Memory Accelerator (TMA) + units and warpgroup-level MMA instructions. + + Background + ---------- + + While a GEMM kernel computes a product of two matrices (rank-2 tensors), the source data may + come from higher-rank tensors by combining some if its modes (dimensions) into the row and column + modes of the matrix. These tensors are often outputs from previous layers of a network, and the + data may sometimes need to be reordered in memory before a GEMM is computed. Similarly, the output + of a GEMM may need to be reordered before a subsequent operation can be executed. + + Consider this sample PyTorch code: + + # Forward pass + D = torch.mm(A, B).view(M/D1, D1, D2, N/D2).permute(0, 2, 1, 3) + + # Backward pass + grad_A = torch.mm(grad_D.permute(0, 2, 1, 3).view(M, N), B) + + Executing the reordering as a separate operation requires committing intermediate tensor to memory + and increases the latency and memory footprint of the model. By fusing the permutation with either + reading of A/B matrices or writing of D matrix, we can avoid the unnecessary global memory traffic + and kernel launch overhead. + + Implementation + -------------- + + The approach relies on two things: + - The ability of CUTLASS 3 to naturally perform general tensor contractions (GETT) owing to the + flexibility of CuTe's hierarchical layouts (see example 51_hopper_gett for more details). + - The harware capabilities of Hopper TMA units that allow for loading multidimensional tensors with + (almost) arbitrary strides, which can be used to represent a permuted view of the data. + + In this example we reuse the permutation classes of examples 39_gemm_permute as operation tags. + For each tag, a specialization of struct PermuteTraits<> provides the necessary information about + the target tensor shape and ordering of modes. The main class, ExampleRunner, then figures out the + overall (hierarchical) shape of the GEMM operation and computes the shape and strides for each + tensor taking into account the permutation applied. We highlight the importance of specifying + consistent multidimensional shapes for all tensors (even those that are not permuted), as well as + choosing hierarchical GEMM tile sizes that best fit those shapes (in cases where some tensor + dimensions are known at compile time). + + In addition, this example implements a standalone permutation kernel that is used to both verify + correctness of the fused kernel and benchmark the fused kernel against an unfused version that + writes intermediate tensor to memory. +*/ + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/permute.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "helper.h" +#include "permute_kernel.cuh" +#include "permute_traits.hpp" + +namespace example +{ + +struct Options { + + bool help; + + cutlass::gemm::BatchedGemmCoord problem_size; + + float alpha; + float beta; + + bool reference_check; + int iterations; + + bool verbose; + + Options(): + help(false), + problem_size({2048, 2048, 2048, 8}), + alpha(1.0), + beta(1.0), + reference_check(true), + iterations(20), + verbose(false) { } + + bool valid() const { + return problem_size.m() > 0 + && problem_size.n() > 0 + && problem_size.k() > 0 + && problem_size.batch() > 0 + && iterations > 0; + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("m", problem_size.m()); + cmd.get_cmd_line_argument("n", problem_size.n()); + cmd.get_cmd_line_argument("k", problem_size.k()); + cmd.get_cmd_line_argument("batch_size", problem_size.batch()); + + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + + cmd.get_cmd_line_argument("check", reference_check, true); + cmd.get_cmd_line_argument("iterations", iterations); + + cmd.get_cmd_line_argument("verbose", verbose, false); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << + "53_hopper_gemm_permute example\n" + "\n" + " This example uses the CUTLASS Library to fuse permute() on input/output tensors with GEMM\n" + "\n" + "Options:\n" + " --help If specified, displays this usage statement.\n" + " --m= GEMM M dimension\n" + " --n= GEMM N dimension\n" + " --k= GEMM K dimension\n" + " --alpha= GEMM alpha parameter\n" + " --beta= GEMM beta parameter\n" + " --iterations= Number of profiling iterations to perform.\n" + " --check= Validate results against a reference (unfused) imlementation" + " --verbose= Enable verbose output" + "\n" + "Examples:\n" + "\n" + "$ ./examples/53_hopper_gemm_permute/53_hopper_gemm_permute --m=4096 --n=2048 --k=3072 --batch_size=8\n"; + + return out; + } +}; + +using namespace cute; + +// Check the shapes assigned to the same mode of different tensors, +// ensure all permuted shapes are the same and return that shape. +template +auto +select_mode_shape(Shapes const & ... shapes) { + auto permuted_shapes = filter_tuple(cute::make_tuple(shapes...), [](auto shape) { + if constexpr (rank(shape) > 1) { + return cute::make_tuple(shape); + } + else { + return cute::make_tuple(); + } + }); + if constexpr (rank(permuted_shapes) == 0) { + return get<0>(cute::make_tuple(shapes...)); + } + else { + auto ref_shape = get<0>(permuted_shapes); + for_each(permuted_shapes, [&](auto shape) { + // This static assert fails to compile on GCC 7.5 + // static_assert(is_same::value, "Inconsistent shapes for the same mode"); + // This runtime check can be skipped if all permutations are required to be static. + if (shape != ref_shape) + { + print("Inconsistent shapes for the same mode: "); + print(ref_shape); print(" and "); print(shape); print("\n"); + exit(EXIT_FAILURE); + } + }); + return ref_shape; + } +} + +template +auto +compute_default_stride(Shape const & shape, StrideOrig const & stride_orig) { + // Only supports column-major and row-major, batch stride always comes last + if constexpr (is_constant<1, decltype(get<0>(stride_orig))>::value) { + return compact_col_major(shape); + } + else + { + return compact_order(shape, Step<_1,_0,_2>{}); + } +} + +// Divide a static scalar TileSize into static modes of Shape until either: +// - a dynamic mode is encountered +// - we run out of size to divide +// - no longer divisible by next shape +// Examples: +// select_tile_shape(_128, (_8,_16)) -> (_8,_16) +// select_tile_shape(_128, (_8,_32)) -> (_8,_16) +// select_tile_shape(_128, (_8, _4)) -> (_8,_4,_4) +// select_tile_shape(_128, (_8, 4)) -> (_8,_16) +template +auto +select_tile_shape(TileSize size, Shape const& shape) +{ + static_assert(is_static::value, "Tile size must be static"); + if constexpr (rank(Shape{}) == 0) { + return cute::make_tuple(size); + } + else { + if constexpr (is_static>::value) { + auto div = front(shape); + if constexpr (size > div and size % div == 0) { + return prepend(select_tile_shape(size / div, take<1,tuple_size_v>(shape)), div); + } + else { + return cute::make_tuple(size); + } + } + else { + return cute::make_tuple(size); + } + } +} + +template +class ExampleRunner +{ +private: + + // Define shapes for each operand and original GEMM problem as a whole. + + using MatrixShape = Shape; // [M,N,L]/[M,K,L]/[N,K,L] + using ProblemShape = Shape; // [M,N,K,L] + + // Determine the CuTe stride for each of the four operands. + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using StrideC = cutlass::gemm::TagToStrideC_t; + using StrideD = cutlass::gemm::TagToStrideC_t; + + // Flags to check which operands will be permuted. + + static constexpr bool DoPermuteA = not cutlass::layout::is_trivial_permute; + static constexpr bool DoPermuteB = not cutlass::layout::is_trivial_permute; + static constexpr bool DoPermuteC = not cutlass::layout::is_trivial_permute; + static constexpr bool DoPermuteD = not cutlass::layout::is_trivial_permute; + + // For input operands, we must use inverse of the permutation operation + // to read data that is stored in original (un-permuted) order. + + using PermuteAReal = typename cutlass::layout::InversePermute::type; + using PermuteBReal = typename cutlass::layout::InversePermute::type; + using PermuteCReal = typename cutlass::layout::InversePermute::type; + using PermuteDReal = PermuteD; + + // Get permutation layout for each operand. + // A permutation layout is a rank-3 layout in the usual CuTe mode ordering, + // but each mode may have a nested shape corresponding to the reshaping of + // the matrix into a multidimensional tensor, and the strides are computed + // taking the desired permutation into account. + + template + using LayoutPermute = remove_cvref_t(make_layout(MatrixShape{}, Stride{})))>; + + using LayoutAP = LayoutPermute; + using LayoutBP = LayoutPermute; + using LayoutCP = LayoutPermute; + using LayoutDP = LayoutPermute; + + // Now we want to build the unified problem shape for permute-GEMM. + // To do this, we check the corresponding mode in each tensor that has it. + // If at least one tensor has a mode that has been reshaped (i.e. rank > 1), + // its shape will be used as the reference shape for that mode in all tensors. + // If multiple tensors have reshaped mode, we additionally check that their + // shapes for that mode match. Otherwise, we can't define a consistent GEMM shape. + + using ShapeM = decltype(select_mode_shape(shape<0>(LayoutAP{}), shape<0>(LayoutCP{}), shape<0>(LayoutDP{}))); + using ShapeN = decltype(select_mode_shape(shape<0>(LayoutBP{}), shape<1>(LayoutCP{}), shape<1>(LayoutDP{}))); + using ShapeK = decltype(select_mode_shape(shape<1>(LayoutAP{}), shape<1>(LayoutBP{}))); + using ShapeL = decltype(select_mode_shape(shape<2>(LayoutAP{}), shape<2>(LayoutBP{}), shape<2>(LayoutCP{}), shape<2>(LayoutDP{}))); + + using ProblemShapePermute = Shape; + + using ShapeAPermute = Shape; + using ShapeBPermute = Shape; + using ShapeCPermute = Shape; + using ShapeDPermute = Shape; + + // Next, we must define the strides for each tensor. + // If the tensor is permuted, we take the strides produced by the permutation function. + // Otherwise, we compute default strides induced by the new (multidimensional) shape of the tensor. + // + // This won't always work in general if multiple tensors are permuted: e.g. if PermuteA affects + // modes M and K, and PermuteB affects modes N and L, the single stride for mode L of tensor A + // computed by PermuteA will be non-congruent with it's shape that is changed by PermuteB. + // To handle this correctly, a more complicated logic is needed to reconstruct multi-mode strides. + // This is not addressed here, as it's not a common requirement to permute multiple tensors in one GEMM. + + using StrideAPermute = conditional_t, decltype(compute_default_stride(ShapeAPermute{}, StrideA{}))>; + using StrideBPermute = conditional_t, decltype(compute_default_stride(ShapeBPermute{}, StrideB{}))>; + using StrideCPermute = conditional_t, decltype(compute_default_stride(ShapeCPermute{}, StrideC{}))>; + using StrideDPermute = conditional_t, decltype(compute_default_stride(ShapeDPermute{}, StrideD{}))>; + + // We need to select optimal tile shape based on the tile size specified by the user. + // This is done by dividing the tile size in each mode by the mode shape as much + // as possible (i.e. until we run out of tile size or encounter a dynamic sub-shape). + + using TileMPermute = decltype(select_tile_shape(get<0>(TileShape{}), ShapeM{})); + using TileNPermute = decltype(select_tile_shape(get<1>(TileShape{}), ShapeN{})); + using TileKPermute = decltype(select_tile_shape(get<2>(TileShape{}), ShapeK{})); + + using TileShapePermute = Shape; + + // Now we are ready to define the GEMM kernel types for both fused permute and reference paths. + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementEpilogue, + ElementC, StrideC, 128 / cutlass::sizeof_bits::value, + ElementD, StrideD, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveEpiloguePermute = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShapePermute, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementEpilogue, + ElementC, StrideCPermute, 128 / cutlass::sizeof_bits::value, + ElementD, StrideDPermute, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, StrideA, 128 / cutlass::sizeof_bits::value, + ElementB, StrideB, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveMainloopPermute = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, StrideAPermute, 128 / cutlass::sizeof_bits::value, + ElementB, StrideBPermute, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + TileShapePermute, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using GemmKernelPermute = cutlass::gemm::kernel::GemmUniversal< + ProblemShapePermute, + CollectiveMainloopPermute, + CollectiveEpiloguePermute + >; + + using GemmReference = cutlass::gemm::device::GemmUniversalAdapter; + using GemmPermute = cutlass::gemm::device::GemmUniversalAdapter; + + // Data members + + cutlass::gemm::BatchedGemmCoord problem_size; + ProblemShape problem_shape; + cutlass::KernelHardwareInfo hw_info; + + ElementEpilogue alpha; + ElementEpilogue beta; + + MatrixShape shape_A; + MatrixShape shape_B; + MatrixShape shape_C; + MatrixShape shape_D; + + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + + LayoutAP layout_AP; + LayoutBP layout_BP; + LayoutCP layout_CP; + LayoutDP layout_DP; + + ShapeM shape_M; + ShapeN shape_N; + ShapeK shape_K; + ShapeL shape_L; + + ProblemShapePermute problem_shape_permute; + + ShapeAPermute shape_A_permute; + ShapeBPermute shape_B_permute; + ShapeCPermute shape_C_permute; + ShapeDPermute shape_D_permute; + + StrideAPermute stride_A_permute; + StrideBPermute stride_B_permute; + StrideCPermute stride_C_permute; + StrideDPermute stride_D_permute; + + cutlass::device_memory::allocation tensor_a; + cutlass::device_memory::allocation tensor_b; + cutlass::device_memory::allocation tensor_c; + cutlass::device_memory::allocation tensor_d; + + cutlass::device_memory::allocation tensor_a_permuted; + cutlass::device_memory::allocation tensor_b_permuted; + cutlass::device_memory::allocation tensor_c_permuted; + cutlass::device_memory::allocation tensor_d_unpermuted; + cutlass::device_memory::allocation tensor_d_reference; + + cutlass::gemm::GemmUniversalMode gemm_mode; + + GemmPermute gemm_permute; + typename GemmPermute::Arguments arguments_permute; + cutlass::device_memory::allocation workspace_permute; + + GemmReference gemm_reference; + typename GemmReference::Arguments arguments_reference; + cutlass::device_memory::allocation workspace_reference; + + public: + + ExampleRunner(Options const & options, cutlass::KernelHardwareInfo const & hw_info) + : problem_size(options.problem_size), + problem_shape(problem_size.m(), problem_size.n(), problem_size.k(), problem_size.batch()), + hw_info(hw_info), + alpha(options.alpha), + beta(options.beta), + shape_A(make_shape(problem_size.m(), problem_size.k(), problem_size.batch())), + shape_B(make_shape(problem_size.n(), problem_size.k(), problem_size.batch())), + shape_C(make_shape(problem_size.m(), problem_size.n(), problem_size.batch())), + shape_D(make_shape(problem_size.m(), problem_size.n(), problem_size.batch())), + stride_A(cutlass::make_cute_packed_stride(StrideA{}, shape_A)), + stride_B(cutlass::make_cute_packed_stride(StrideB{}, shape_B)), + stride_C(cutlass::make_cute_packed_stride(StrideC{}, shape_C)), + stride_D(cutlass::make_cute_packed_stride(StrideD{}, shape_D)), + layout_AP(make_permute_layout(make_layout(shape_A, stride_A))), + layout_BP(make_permute_layout(make_layout(shape_B, stride_B))), + layout_CP(make_permute_layout(make_layout(shape_C, stride_C))), + layout_DP(make_permute_layout(make_layout(shape_D, stride_D))), + shape_M(select_mode_shape(shape<0>(layout_AP), shape<0>(layout_CP), shape<0>(layout_DP))), + shape_N(select_mode_shape(shape<0>(layout_BP), shape<1>(layout_CP), shape<1>(layout_DP))), + shape_K(select_mode_shape(shape<1>(layout_AP), shape<1>(layout_BP))), + shape_L(select_mode_shape(shape<2>(layout_AP), shape<2>(layout_BP), shape<2>(layout_CP), shape<2>(layout_DP))), + problem_shape_permute(shape_M, shape_N, shape_K, shape_L), + shape_A_permute(make_shape(shape_M, shape_K, shape_L)), + shape_B_permute(make_shape(shape_N, shape_K, shape_L)), + shape_C_permute(make_shape(shape_M, shape_N, shape_L)), + shape_D_permute(make_shape(shape_M, shape_N, shape_L)), + stride_A_permute(conditional_return(layout_AP.stride(), compute_default_stride(shape_A_permute, stride_A))), + stride_B_permute(conditional_return(layout_BP.stride(), compute_default_stride(shape_B_permute, stride_B))), + stride_C_permute(conditional_return(layout_CP.stride(), compute_default_stride(shape_C_permute, stride_C))), + stride_D_permute(conditional_return(layout_DP.stride(), compute_default_stride(shape_D_permute, stride_D))), + tensor_a(problem_size.m() * problem_size.k() * problem_size.batch()), + tensor_b(problem_size.k() * problem_size.n() * problem_size.batch()), + tensor_c(problem_size.m() * problem_size.n() * problem_size.batch()), + tensor_d(problem_size.m() * problem_size.n() * problem_size.batch()), + tensor_a_permuted(problem_size.m() * problem_size.k() * problem_size.batch()), + tensor_b_permuted(problem_size.k() * problem_size.n() * problem_size.batch()), + tensor_c_permuted(problem_size.m() * problem_size.n() * problem_size.batch()), + tensor_d_unpermuted(problem_size.m() * problem_size.n() * problem_size.batch()), + tensor_d_reference(problem_size.m() * problem_size.n() * problem_size.batch()), + gemm_mode(problem_size.batch() > 1 ? cutlass::gemm::GemmUniversalMode::kBatched : cutlass::gemm::GemmUniversalMode::kGemm), + arguments_permute{ + gemm_mode, + problem_shape_permute, + { + tensor_a.get(), stride_A_permute, + tensor_b.get(), stride_B_permute, + }, + { + { alpha, beta }, + tensor_c.get(), stride_C_permute, + tensor_d.get(), stride_D_permute + }, + hw_info + }, + workspace_permute(GemmPermute::get_workspace_size(arguments_permute)), + arguments_reference{ + gemm_mode, + problem_shape, + { + DoPermuteA ? tensor_a_permuted.get() : tensor_a.get(), stride_A, + DoPermuteB ? tensor_b_permuted.get() : tensor_b.get(), stride_B + }, + { + { alpha, beta }, + DoPermuteC ? tensor_c_permuted.get() : tensor_c.get(), stride_C, + DoPermuteD ? tensor_d_unpermuted.get() : tensor_d_reference.get(), stride_D + }, + hw_info + }, + workspace_reference(GemmReference::get_workspace_size(arguments_reference)) + { + if (options.verbose) { + print("Original GEMM problem:\n"); + print(" Problem shape: "); print(problem_shape); print("\n"); + print(" Layout A: "); print(make_layout(shape_A, stride_A)); print("\n"); + print(" Layout B: "); print(make_layout(shape_B, stride_B)); print("\n"); + print(" Layout C: "); print(make_layout(shape_C, stride_C)); print("\n"); + print(" Layout D: "); print(make_layout(shape_D, stride_D)); print("\n"); + print(" Tile shape: "); print(TileShape{}); print("\n"); + print("With fused permutations:\n"); + print(" Problem shape: "); print(problem_shape_permute); print("\n"); + print(" Layout A: "); print(make_layout(shape_A_permute, stride_A_permute)); print("\n"); + print(" Layout B: "); print(make_layout(shape_B_permute, stride_B_permute)); print("\n"); + print(" Layout C: "); print(make_layout(shape_C_permute, stride_C_permute)); print("\n"); + print(" Layout D: "); print(make_layout(shape_D_permute, stride_D_permute)); print("\n"); + print(" Tile shape: "); print(TileShapePermute{}); print("\n"); + } + + cutlass::reference::device::BlockFillRandomUniform(tensor_a.get(), tensor_a.size(), 1, ElementA(7), ElementA(-8), 0); + cutlass::reference::device::BlockFillRandomUniform(tensor_b.get(), tensor_b.size(), 2, ElementB(7), ElementB(-8), 0); + cutlass::reference::device::BlockFillRandomUniform(tensor_c.get(), tensor_c.size(), 3, ElementC(7), ElementC(-8), 0); + cutlass::reference::device::BlockFillSequential(tensor_d.get(), tensor_d.size(), ElementD(0), ElementD(0)); + + auto const gemm_init = [](auto & gemm, auto const & arguments, auto & workspace) { + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Requested GEMM kernel cannot be used for this problem.\n" + << "Check problem sizes and alignment requirements." << std::endl; + exit(EXIT_FAILURE); + } + status = gemm.initialize(arguments, workspace.get()); + CUTLASS_CHECK(status); + }; + + gemm_init(gemm_permute, arguments_permute, workspace_permute ); + gemm_init(gemm_reference, arguments_reference, workspace_reference); + } + + void debug_output(std::ostream & os) + { + auto print_tensor = [](std::ostream &os, char const * name, auto const & data, auto shape, auto stride) + { + std::vector> h_data(data.size()); + data.copy_to_host(h_data.data()); + Tensor t = make_tensor(h_data.data(), shape, stride); + os << "\n" << name << ": " << std::setw(4) << t << std::endl; + }; + auto [M,N,K,L] = problem_shape; + print_tensor(os, "A", tensor_a, make_shape(M,K,L), stride_A); + print_tensor(os, "B", tensor_b, make_shape(N,K,L), stride_B); + print_tensor(os, "C", tensor_c, make_shape(M,N,L), stride_C); + print_tensor(os, "D", tensor_d, make_shape(M,N,L), stride_D); + print_tensor(os, "D reference", tensor_d_reference, make_shape(M,N,L), stride_D); + } + + template + static float + run_gemm(Gemm &gemm) + { + GpuTimer timer; + if constexpr (DoTime) timer.start(); + cutlass::Status status = gemm.run(); + CUTLASS_CHECK(status); + if constexpr (DoTime) timer.stop(); + if constexpr (DoTime) return timer.elapsed_millis(); + else return 0; + } + + template + static float + run_permute(cutlass::device_memory::allocation const & input, + cutlass::device_memory::allocation & output, + Layout const& layout, + cutlass::KernelHardwareInfo const & hw_info) + { + auto idx = find_if(layout.stride(), [](auto x){ return not is_constant<1, decltype(x)>{}; }); + auto stride = get(layout.stride()); + + GpuTimer timer; + if constexpr (DoTime) timer.start(); + permute::kBatched, Permute>(input.get(), + output.get(), + size(take<0,2>(layout)), + static_cast(stride), + shape<2>(layout), + hw_info); + if constexpr (DoTime) timer.stop(); + if constexpr (DoTime) return timer.elapsed_millis(); + else return 0; + }; + + template + auto run_reference(Gemm2 &gemm) + { + float permute_time = 0.f; + if constexpr (DoPermuteA) { + auto orig_layout = make_original_layout(make_layout(shape_A, stride_A)); + permute_time += run_permute(tensor_a, tensor_a_permuted, orig_layout, hw_info); + } + if constexpr (DoPermuteB) { + auto orig_layout = make_original_layout(make_layout(shape_B, stride_B)); + permute_time += run_permute(tensor_b, tensor_b_permuted, select<1,0,2>(orig_layout), hw_info); + } + if constexpr (DoPermuteC) { + auto orig_layout = make_original_layout(make_layout(shape_C, stride_C)); + permute_time += run_permute(tensor_c, tensor_c_permuted, orig_layout, hw_info); + } + + float gemm_time = run_gemm(gemm); + + if constexpr (DoPermuteD) { + auto orig_layout = make_layout(shape_D, stride_D); + permute_time += run_permute(tensor_d_unpermuted, tensor_d_reference, orig_layout, hw_info); + } + + return cute::make_tuple(gemm_time, permute_time); + } + + bool verify() + { + run_gemm(gemm_permute); + run_reference(gemm_reference); + return cutlass::reference::device::BlockCompareEqual(tensor_d.get(), tensor_d_reference.get(), tensor_d.size()); + } + + bool run(Options const &options) + { + if (options.reference_check) { + if (!verify()) { + std::cout << "Failed validation" << std::endl; +#if 1 + debug_output(std::cout); +#endif + return false; + } + else { + std::cout << "Passed validation" << std::endl; + } + } + + // + // Run profiling loop + // + + auto const benchmark = [&](auto name, auto func) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + func(); + } + timer.stop(); + + double runtime = timer.elapsed_millis() / double(options.iterations); + double gflops = 2 * double(problem_size.product()) / 1e6 / runtime; // Two flops per multiply-add + + std::cout << name << ":\n"; + std::cout << " Runtime: " << runtime << " ms\n"; + std::cout << " GFLOPs: " << gflops << "\n"; + }; + + benchmark("Fused GEMM+permute", [&](){ run_gemm(gemm_permute); }); + benchmark("Unfused GEMM+permute", [&](){ run_reference(gemm_reference); }); + benchmark("Standalone GEMM only", [&](){ run_gemm(gemm_reference); }); + std::cout << "\n"; + + return true; + } +}; + +} // namespace example + + +int main(int argc, char const **argv) +{ + bool notSupported = false; + + // CUDA 12 minimum required + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA Toolkit version 12 or later.\n"; + notSupported = true; + } + + cudaDeviceProp props; + CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); + + if (props.major < 9) { + std::cerr << "This example requires a device with compute capability 90 or higher.\n"; + notSupported = true; + } + + if (notSupported) { + return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems + } + + example::Options options; + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << "\n"; + return EXIT_SUCCESS; + } + + if (!options.valid()) { + std::cerr << "Invalid arguments." << "\n"; + return EXIT_FAILURE; + } + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + using namespace cute; + + // Define the data types + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + using ElementAccumulator = float; + using ElementEpilogue = float; + + // M=64 for TMA epilogue + using TileShape = Shape<_128,_128,_64>; + + // Cluster launch with TMA multicast for better perf + using ClusterShape = Shape<_2,_2,_1>; + + bool result = true; + +#define COMPILE_ALL_EXAMPLES 0 + + // REGULAR GEMMS + + { + print("===================================================\n"); + print("Tensor A: RowMajor, Tensor4DPermute0213<8,16>\n"); + using Runner = example::ExampleRunner, + ElementB, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementC, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementD, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#if COMPILE_ALL_EXAMPLES + { + print("===================================================\n"); + print("Tensor A: ColumnMajor, Tensor4DPermute0213<8,16>\n"); + using Runner = example::ExampleRunner, + ElementB, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementC, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementD, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } + { + print("===================================================\n"); + print("Tensor B: RowMajor, Tensor4DPermute0213<8,16>\n"); + using Runner = example::ExampleRunner, + ElementC, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementD, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#endif + { + print("===================================================\n"); + print("Tensor B: ColumnMajor, Tensor4DPermute0213<8,16>\n"); + using Runner = example::ExampleRunner, + ElementC, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementD, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } + { + print("===================================================\n"); + print("Tensor D: RowMajor, Tensor4DPermute0213<8,16>\n"); + using Runner = example::ExampleRunner, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#if COMPILE_ALL_EXAMPLES + { + print("===================================================\n"); + print("Tensor D: ColumnMajor, Tensor4DPermute0213<8,16>\n"); + using Runner = example::ExampleRunner, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#endif + { + print("===================================================\n"); + print("Tensor A: RowMajor, Tensor5DPermute20314<16,8,4>\n"); + using Runner = example::ExampleRunner, + ElementB, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementC, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementD, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#if COMPILE_ALL_EXAMPLES + { + print("===================================================\n"); + print("Tensor A: ColumnMajor, Tensor5DPermute02413<16,8,4>\n"); + using Runner = example::ExampleRunner, + ElementB, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementC, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementD, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#endif + { + print("===================================================\n"); + print("Tensor D: RowMajor, Tensor5DPermute20314<16,8,4>\n"); + using Runner = example::ExampleRunner, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#if COMPILE_ALL_EXAMPLES + { + print("===================================================\n"); + print("Tensor D: ColumnMajor, Tensor5DPermute02413<16,8,4>\n"); + using Runner = example::ExampleRunner, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#endif + + // BATCHED GEMMS + + { + print("===================================================\n"); + print("Tensor A: RowMajor, Tensor4DPermuteBMM0213<8>\n"); + using Runner = example::ExampleRunner, + ElementB, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementC, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementD, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } + { + print("===================================================\n"); + print("Tensor D: RowMajor, Tensor4DPermuteBMM0213<8>\n"); + using Runner = example::ExampleRunner, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#if COMPILE_ALL_EXAMPLES + { + print("===================================================\n"); + print("Tensor A: ColumnMajor, Tensor4DPermuteBMM0321<8>\n"); + using Runner = example::ExampleRunner, + ElementB, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementC, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementD, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } + { + print("===================================================\n"); + print("Tensor D: RowMajor, Tensor4DPermuteBMM0321<8>\n"); + using Runner = example::ExampleRunner, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#endif + + return result ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/examples/53_hopper_gemm_permute/CMakeLists.txt b/examples/53_hopper_gemm_permute/CMakeLists.txt new file mode 100644 index 0000000000..c831ac6a68 --- /dev/null +++ b/examples/53_hopper_gemm_permute/CMakeLists.txt @@ -0,0 +1,36 @@ + +# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 53_hopper_gemm_permute + 53_hopper_gemm_permute.cu + ) + diff --git a/examples/53_hopper_gemm_permute/permute_kernel.cuh b/examples/53_hopper_gemm_permute/permute_kernel.cuh new file mode 100644 index 0000000000..2d022af15b --- /dev/null +++ b/examples/53_hopper_gemm_permute/permute_kernel.cuh @@ -0,0 +1,92 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Simple permutation kernel implementation. +*/ + +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/tensor_view.h" +#include "cutlass/fast_math.h" +#include "cute/numeric/uint128.hpp" + +namespace example +{ + +/** + * Assumes column-major input (M mode is contiguous, N mode is strided). + * For row major, the inputs must be switched accordingly. +*/ +template +__global__ void +permute_kernel(Element const* __restrict__ input, + Element* __restrict__ output, + Permute permute, + int64_t num_elems, + cutlass::FastDivmod stride_divmod) +{ + // CUTLASS 2.x batched permute functions assume 0 batch stride for target tensor + Element const * input_b = input + blockIdx.z * num_elems; + Element * output_b = output + (Batched ? 0 : blockIdx.z * num_elems); + for (int64_t k = threadIdx.x + blockIdx.x * blockDim.x; k < num_elems; k += blockDim.x * gridDim.x) + { + int i, j; + stride_divmod(j, i, k); + output_b[permute(cutlass::PitchLinearCoord(i, j))] = input_b[i + j * stride_divmod.divisor]; + } +} + +template +void permute(Element const* input, + Element * output, + int64_t num_elems, + int stride, + int batch_count, + cutlass::KernelHardwareInfo const& hw_info) +{ + // Upcast to uint128_t data type + int factor = 128 / cutlass::sizeof_bits::value; + assert(stride % factor == 0); + int stride_upcast = stride/factor; + int64_t num_elems_upcast = num_elems / factor; + Permute permute_upcast(cutlass::PitchLinearCoord(stride_upcast, int(num_elems_upcast/stride_upcast)), stride_upcast); + + cutlass::FastDivmod stride_divmod(stride); + dim3 blocks(hw_info.sm_count, 1, batch_count); + permute_kernel<<>>(reinterpret_cast(input), + reinterpret_cast(output), + permute_upcast, + num_elems_upcast, + stride_upcast); +} + +} // namespace example diff --git a/examples/53_hopper_gemm_permute/permute_traits.hpp b/examples/53_hopper_gemm_permute/permute_traits.hpp new file mode 100644 index 0000000000..55c7641853 --- /dev/null +++ b/examples/53_hopper_gemm_permute/permute_traits.hpp @@ -0,0 +1,273 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Additional permutation information for the example. +*/ + +#include "cutlass/layout/permute.h" +#include "cutlass/gemm/gemm.h" + +namespace example +{ + +using namespace cute; + +// This struct is specialized below for different CUTLASS 2.x permutation ops +// to describe the operation in terms of target CuTe shape and stride order. +template +struct PermuteTraits {}; + +// Use X as a placeholder for shape division result +using X = Underscore; + +// Reshape a rank-2 shape into a multidimensional shape. +// Input: +// shape = (A, B, ...) +// target_shape = ((A1, ..., X, ..., Am), (B1, ..., X, ..., Bn), ...) +// Output: +// ((A1, ..., A/prod(A1..Am), ..., Am), (B1, ..., B/prod(B1..Bn), ..., Bn), ...) +template +constexpr auto +reshape(Shape const& shape, TargetShape const& target_shape) +{ + if constexpr (is_tuple::value) { + return cute::transform(shape, target_shape, [](auto && s, auto && t){ return reshape(s, t); }); + } + else { + auto idx = find_if(target_shape, [](auto x){ return is_underscore{}; }); + constexpr int I = decltype(idx)::value; + static_assert(I < tuple_size_v, "Each mode of TargetShape must contain a placeholder X"); + auto divisors = remove(target_shape); + assert(shape % product(divisors) == 0); + return replace(target_shape, shape / product(divisors)); + } +} + +// Given a tensor layout, compute a permutation layout consisting of: +// - sub-modes corresponding to the implied multidimensional shape of the source tensor +// - strides accounting for the permutation operation being performed +template +constexpr auto +make_permute_layout(Layout const& layout) { + static_assert(rank(Shape{}) == 3, "Only rank-3 layouts are supported"); + if constexpr (Transpose) { + // Deal with tensor B by transposing appropriately before and after computing the permute layout. + // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch]. + return select<1,0,2>(make_permute_layout(select<1,0,2>(layout))); + } + else { + if constexpr (cutlass::layout::is_trivial_permute) { + // Special case for NoPermute. Use a depth-2 layout for consistency with other permutations. + using ShapeProfile = tuple, tuple, tuple>; + return unflatten(layout, ShapeProfile{}); + } + else { + // Here's where the permutation layout is actually built + using ShapeProfile = typename PermuteTraits::ShapeProfile; + using StrideOrder = typename PermuteTraits::StrideOrder; + return make_ordered_layout(reshape(layout.shape(), ShapeProfile{}), StrideOrder{}); + } + } +} + +namespace detail +{ + +template +struct is_constant_pred { + template + constexpr auto operator()(T) { + return is_constant{}; + } +}; + +template +constexpr auto +inverse_impl(Permutation const & perm, seq) { + return cute::make_tuple(Int{})>{}...); +} + +} // namespace detail + +// Compute an inverse of a permutation represented as a tuple of cute::Int<> +template +constexpr auto +inverse(Permutation const & perm) { + auto flat_perm = flatten(perm); + return unflatten(detail::inverse_impl(flat_perm, tuple_seq{}), perm); +} + +template +using inverse_t = decltype(inverse(T{})); + +// Given a rank-2 layout of tensor that is assumed to have been permuted, +// compute the original rank-2 layout of the tensor prior to the permutation. +// This is needed to form the correct input to the standalone permutation kernel. +template +constexpr auto +make_original_layout(Layout const& layout) { + static_assert(rank(Shape{}) == 3, "Only rank-3 layouts are supported"); + if constexpr (Transpose) { + // Deal with tensor B by transposing appropriately before and after computing the permute layout. + // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch]. + return select<1,0,2>(make_original_layout(select<1,0,2>(layout))); + } + else { + using ShapeProfile = typename PermuteTraits::ShapeProfile; + using IndexOrder = typename PermuteTraits::IndexOrder; + using OrigOrder = conditional_t(), seq<0,1,2>, seq<1,0,2>>; + auto orig_shape = select(flatten(reshape(layout.shape(), ShapeProfile{})), IndexOrder{}); + // print("Permuted shape: "); print(reshape(layout.shape(), ShapeProfile{})); print("\n"); + // print("Original shape: "); print(orig_shape); print("\n"); + return make_ordered_layout(product_each(orig_shape), OrigOrder{}); + } +} + +/////////////// Tensor4DPermute0213 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,X>, Shape>; + using IndexOrder = Step, Step<_1,_3>, Step<_4>>; + using StrideOrder = inverse_t; // Step, Step<_1,_3>, Step<_4>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,X>, Shape>; + using IndexOrder = Step, Step<_1,_3>, Step<_4>>; + using StrideOrder = inverse_t; // Step, Step<_1,_3>, Step<_4>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape,X>, Shape>, Shape>; + using IndexOrder = Step, Step<_0,_2>, Step<_4>>; + using StrideOrder = Step, Step<_0,_2>, Step<_4>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape,X>, Shape>, Shape>; + using IndexOrder = Step, Step<_0,_2>, Step<_4>>; + using StrideOrder = Step, Step<_0,_2>, Step<_4>>; +}; + +/////////////// Tensor4DPermuteBMM0321 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = true; + using ShapeProfile = Shape, Shape, Shape,X>>; + using IndexOrder = Step, Step<_1>, Step<_3>>; + using StrideOrder = Step, Step<_2>, Step<_1,_3>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = true; + using ShapeProfile = Shape>, Shape, Shape>; + using IndexOrder = Step, Step<_2>, Step<_1,_3>>; + using StrideOrder = Step, Step<_1>, Step<_3>>; +}; + +/////////////// Tensor4DPermuteBMM0213 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = true; + using ShapeProfile = Shape, Shape, Shape,X>>; + using IndexOrder = Step, Step<_1,_2>, Step<_3>>; + using StrideOrder = Step, Step<_0>, Step<_1,_3>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = true; + using ShapeProfile = Shape, Shape>, Shape>; + using IndexOrder = Step, Step<_1>, Step<_2,_3>>; + using StrideOrder = Step, Step<_0,_2>, Step<_3>>; +}; + +/////////////// Tensor5DPermute02413 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,Int,X>, Shape>; + using IndexOrder = Step, Step<_4,_1,_3>, Step<_5>>; + using StrideOrder = inverse_t; // Step, Step<_1,_4,_2>, Step<_5>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,Int>, Shape>; + using IndexOrder = Step, Step<_1,_4,_2>, Step<_5>>; + using StrideOrder = inverse_t; // Step, Step<_4,_1,_3>, Step<_5>>; +}; + +/////////////// Tensor5DPermute20314 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape,X>, Shape,Int>, Shape>; + using IndexOrder = Step, Step<_3,_1,_4>, Step<_5>>; + using StrideOrder = Step, Step<_0,_2,_4>, Step<_5>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,Int>, Shape>; + using IndexOrder = Step, Step<_2,_4,_1>, Step<_5>>; + using StrideOrder = Step, Step<_0,_3,_1>, Step<_5>>; +}; + +} // namespace example diff --git a/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu b/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu new file mode 100644 index 0000000000..f6291c6e7f --- /dev/null +++ b/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu @@ -0,0 +1,577 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Simple Hopper FP8 GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture + + This example demonstrate a simple way to instantiate and run a FP8 GEMM using the new CUTLASS 3.0 + APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows: + + 1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA) + which are more efficient than the Ampere tensor core instructions. + + 2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large + blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous + copies between thread blocks in a cluster. + + 3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details). + + 4. This example shows all important fusions used by FP8 gemm kernels, + i.e., scale factor for A, B, C, D tensor, the abs_max value of D tensor. + + Examples: + + $ ./examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm --m=2048 --n=2048 --k=2048 +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" + + +#include "helper.h" +#include "hopper_fp8_commandline.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// Auxiliary matrix configuration +using ElementAux = ElementC; +using LayoutAux = LayoutC; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_64,_128,_128>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized; +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux>; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Extract information from Gemm kernel. +using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; +using ElementScalar = typename EpilogueOutputOp::ElementScalar; +using ElementAmax = typename EpilogueOutputOp::ElementAmax; +using ActivationFunctor = typename EpilogueOutputOp::ActivationFn; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; +using StrideAux = StrideD; + +constexpr bool IsDFp8 = + cute::is_same_v or + cute::is_same_v; + +constexpr bool IsAuxFp8 = + cute::is_same_v or + cute::is_same_v; + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +StrideAux stride_aux; +uint64_t seed; + +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_C; +cutlass::HostTensor tensor_D; +cutlass::HostTensor tensor_ref_D; +cutlass::HostTensor tensor_aux; +cutlass::HostTensor tensor_ref_aux; + +using LayoutScalar = cutlass::layout::PackedVectorLayout; +cutlass::HostTensor scalar_alpha; +cutlass::HostTensor scalar_beta; +cutlass::HostTensor scale_A; +cutlass::HostTensor scale_B; +cutlass::HostTensor scale_C; +cutlass::HostTensor scale_D; +cutlass::HostTensor scale_aux; +cutlass::HostTensor abs_max_D; +cutlass::HostTensor reference_abs_max_D; +cutlass::HostTensor abs_max_aux; +cutlass::HostTensor reference_abs_max_aux; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); + stride_aux = stride_D; + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + + tensor_A.resize(a_coord); + tensor_B.resize(b_coord); + tensor_C.resize(c_coord); + tensor_D.resize(c_coord); + tensor_ref_D.resize(c_coord); + + initialize_tensor(tensor_A.host_view(), seed + 2022); + initialize_tensor(tensor_B.host_view(), seed + 2023); + initialize_tensor(tensor_C.host_view(), seed + 2024); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + if (options.save_aux) { + tensor_aux.resize(c_coord); + tensor_aux.sync_device(); + tensor_ref_aux.resize(c_coord); + } + + if (options.device_scale) { + scalar_alpha.resize(cutlass::make_Coord(1)); + scalar_beta.resize(cutlass::make_Coord(1)); + scale_A.resize(cutlass::make_Coord(1)); + scale_B.resize(cutlass::make_Coord(1)); + scale_C.resize(cutlass::make_Coord(1)); + scale_D.resize(cutlass::make_Coord(1)); + scale_aux.resize(cutlass::make_Coord(1)); + + cutlass::reference::host::TensorFill(scalar_alpha.host_view(), options.alpha); + cutlass::reference::host::TensorFill(scalar_beta.host_view(), options.beta); + cutlass::reference::host::TensorFill(scale_A.host_view(), options.scale_a); + cutlass::reference::host::TensorFill(scale_B.host_view(), options.scale_b); + cutlass::reference::host::TensorFill(scale_C.host_view(), options.scale_c); + cutlass::reference::host::TensorFill(scale_D.host_view(), options.scale_d); + cutlass::reference::host::TensorFill(scale_aux.host_view(), options.scale_aux); + + scalar_alpha.sync_device(); + scalar_beta.sync_device(); + scale_A.sync_device(); + scale_B.sync_device(); + scale_C.sync_device(); + scale_D.sync_device(); + scale_aux.sync_device(); + } + + if (IsDFp8 && options.save_amax) { + abs_max_D.resize(cutlass::make_Coord(1)); + abs_max_D.sync_device(); + reference_abs_max_D.resize(cutlass::make_Coord(1)); + } + + if (IsAuxFp8 && options.save_aux && options.save_amax) { + abs_max_aux.resize(cutlass::make_Coord(1)); + abs_max_aux.sync_device(); + reference_abs_max_aux.resize(cutlass::make_Coord(1)); + } +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B}, + { + {}, // epilogue.thread + tensor_C.device_data(), stride_C, + tensor_D.device_data(), stride_D + } + }; + + auto &fusion_args = arguments.epilogue.thread; + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = scalar_alpha.device_data(); + fusion_args.beta_ptr = scalar_beta.device_data(); + fusion_args.scale_a = options.scale_a; + fusion_args.scale_b = options.scale_b; + fusion_args.scale_c = options.scale_c; + fusion_args.scale_a_ptr = scale_A.device_data(); + fusion_args.scale_b_ptr = scale_B.device_data(); + fusion_args.scale_c_ptr = scale_C.device_data(); + + // ignored if tensor types are not fp8 + fusion_args.scale_d = options.scale_d; + fusion_args.scale_aux = options.scale_aux; + fusion_args.scale_d_ptr = scale_D.device_data(); + fusion_args.scale_aux_ptr = scale_aux.device_data(); + + // leaving/setting these as nullptr disables the fusion at runtime + fusion_args.bias_ptr = nullptr; + + if (options.save_aux) { + fusion_args.aux_ptr = tensor_aux.device_data(); + fusion_args.dAux = stride_aux; + if (options.save_amax) { + fusion_args.amax_aux_ptr = abs_max_aux.device_data(); + } + } + + if (options.save_amax) { + fusion_args.amax_D_ptr = abs_max_D.device_data(); + } + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(tensor_A.host_data(), + cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A)); + auto B = cute::make_tensor(tensor_B.host_data(), + cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B)); + auto C = cute::make_tensor(tensor_C.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C)); + auto D = cute::make_tensor(tensor_ref_D.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D)); + auto Aux = cute::make_tensor(tensor_ref_aux.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_aux)); + using unused_t = decltype(D); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + unused_t, // bias + decltype(Aux), + unused_t, // valpha + unused_t, // vbeta + ActivationFunctor + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.Aux = Aux; + epilogue_params.alpha = options.alpha; + epilogue_params.beta = options.beta; + epilogue_params.scale_a = options.scale_a; + epilogue_params.scale_b = options.scale_b; + epilogue_params.scale_c = options.scale_c; + epilogue_params.scale_d = options.scale_d; + epilogue_params.scale_aux = options.scale_aux; + epilogue_params.abs_max_D = reference_abs_max_D.host_data(); + epilogue_params.abs_max_Aux = reference_abs_max_aux.host_data(); + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // compare_reference + tensor_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view()); + + if (IsDFp8 && options.save_amax) { + abs_max_D.sync_host(); + passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0)); + } + + if (options.save_aux) { + tensor_aux.sync_host(); + passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view()); + if (IsAuxFp8 && options.save_amax) { + abs_max_aux.sync_host(); + passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0)); + } + } + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + run(options); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/54_hopper_fp8_warp_specialized_gemm/CMakeLists.txt b/examples/54_hopper_fp8_warp_specialized_gemm/CMakeLists.txt new file mode 100644 index 0000000000..4ea4c23571 --- /dev/null +++ b/examples/54_hopper_fp8_warp_specialized_gemm/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 54_hopper_fp8_warp_specialized_gemm + 54_hopper_fp8_warp_specialized_gemm.cu + ) diff --git a/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp b/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp new file mode 100644 index 0000000000..e465d43f84 --- /dev/null +++ b/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp @@ -0,0 +1,109 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Command line options parsing +struct Options { + + bool help = false; + + float alpha = 1.f, beta = 0.f; + float scale_a = 1.f, scale_b = 1.f, scale_c = 1.f, scale_d = 1.f, scale_aux = 1.f; + bool device_scale = false; + bool save_aux = true; + bool save_amax = true; + int iterations = 1000; + int m = 1024, n = 512, k = 1024, l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("scale_a", scale_a, 1.f); + cmd.get_cmd_line_argument("scale_b", scale_b, 1.f); + cmd.get_cmd_line_argument("scale_c", scale_c, 1.f); + cmd.get_cmd_line_argument("scale_d", scale_d, 1.f); + cmd.get_cmd_line_argument("scale_aux", scale_aux, 1.f); + cmd.get_cmd_line_argument("device_scale", device_scale, false); + cmd.get_cmd_line_argument("save_aux", save_aux, true); + cmd.get_cmd_line_argument("save_amax", save_amax, true); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "52_fp8_hopper_warp_specialized_gemm\n\n" + << " Hopper FP8 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --scale_a= Scaling factor for A\n" + << " --scale_b= Scaling factor for B\n" + << " --scale_c= Scaling factor for C\n" + << " --scale_d= Scaling factor for D (ignored for non-fp8 D)\n" + << " --scale_aux= Scaling factor for the auxiliary tensor (ignored for non-fp8 aux)\n" + << " --device_scale= Copy scalars to device memory before kernel launch (default: false)\n" + << " --save_aux= Save the pre-activation as an auxiliary tensor (default: true)\n" + << " --save_amax= Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "52_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index aea1a89f5b..cf604d861f 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -132,6 +132,9 @@ foreach(EXAMPLE 49_hopper_gemm_with_collective_builder 50_hopper_gemm_with_epilogue_swizzle 51_hopper_gett + 52_hopper_gather_scatter_fusion + 53_hopper_gemm_permute + 54_hopper_fp8_warp_specialized_gemm ) add_subdirectory(${EXAMPLE}) diff --git a/examples/python/00_basic_gemm.ipynb b/examples/python/00_basic_gemm.ipynb index f69f4d6e29..65c1107fe6 100644 --- a/examples/python/00_basic_gemm.ipynb +++ b/examples/python/00_basic_gemm.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "id": "1ef96b3f", "metadata": {}, @@ -12,6 +13,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "962324fd", "metadata": {}, @@ -31,8 +33,8 @@ "\n", "import cutlass\n", "\n", - "# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to\n", - "# omit this information.\n", + "# This controls whether the C++ GEMM declaration will be printed at each step. \n", + "# Set to `False` to omit this information.\n", "print_module = True\n", "\n", "m = 128\n", @@ -60,6 +62,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "f2c7bf48", "metadata": {}, @@ -87,6 +90,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "4a5856de", "metadata": {}, @@ -95,6 +99,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "945478ef", "metadata": {}, @@ -114,6 +119,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "ee5cbbbe", "metadata": {}, @@ -122,6 +128,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "b6c86493", "metadata": {}, @@ -143,6 +150,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "6d27c575", "metadata": {}, @@ -167,6 +175,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "639dcb59", "metadata": {}, @@ -185,6 +194,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "0cce1eae", "metadata": {}, @@ -219,6 +229,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "52a4e318", "metadata": {}, @@ -245,6 +256,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "dc3ad875", "metadata": {}, @@ -267,6 +279,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "c5a8b534", "metadata": {}, @@ -281,13 +294,23 @@ "metadata": {}, "outputs": [], "source": [ - "# Stream K is only supported pre-SM90 (at least when this example was written)\n", - "if plan.cc != 90:\n", + "# Stream K is exposed through the threadblock swizzle method for pre-SM90 kernels,\n", + "# and via the tile_scheduler attribute of the TileDescription for post-SM90 kernels\n", + "if plan.cc < 90:\n", " plan.swizzling_functor = cutlass.swizzle.ThreadblockSwizzleStreamK\n", + " plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)\n", + "else:\n", + " # Stream-K is currently only supported for warp-specialized cooperative kernels\n", + " td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedCooperative\n", + " td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative\n", + " td.tile_scheduler = cutlass.TileSchedulerType.StreamK\n", + "\n", + " plan.compile(td)\n", " plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)" ] }, { + "attachments": {}, "cell_type": "markdown", "id": "5a8ba2ba", "metadata": {}, @@ -327,7 +350,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.6.9" }, "vscode": { "interpreter": { diff --git a/examples/python/01_epilogue.ipynb b/examples/python/01_epilogue.ipynb index 05ab60d620..f7abddd886 100644 --- a/examples/python/01_epilogue.ipynb +++ b/examples/python/01_epilogue.ipynb @@ -102,7 +102,7 @@ "outputs": [], "source": [ "tensor_D_relu = np.zeros(tensor_C.shape).astype(type_D)\n", - "plan.activation = cutlass.epilogue.relu\n", + "plan.activation = \"relu\"\n", "plan.run(tensor_A, tensor_B, tensor_C, tensor_D_relu, print_module=print_module)" ] }, @@ -169,13 +169,25 @@ " plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)" ] }, + { + "cell_type": "markdown", + "id": "18828622", + "metadata": {}, + "source": [ + "To add an activation with parameter such as `leaky_relu`, a tuple should be provided containing the activation function name and the (or a list of) parameter." + ] + }, { "cell_type": "code", "execution_count": null, - "id": "751f8d92", + "id": "53108eae", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "negative_slope = 0.5\n", + "plan.activation = (\"leaky_relu\", negative_slope)\n", + "plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)" + ] } ], "metadata": { diff --git a/examples/python/02_pytorch_extension_grouped_gemm.ipynb b/examples/python/02_pytorch_extension_grouped_gemm.ipynb index 567a583a42..b0cdb0edfd 100644 --- a/examples/python/02_pytorch_extension_grouped_gemm.ipynb +++ b/examples/python/02_pytorch_extension_grouped_gemm.ipynb @@ -8,7 +8,7 @@ "source": [ "# Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension\n", "This notebook walks through a basic example of using the CUTLASS Python interface to declare\n", - "a grouped GEMM kernel and export it as a PyTorch CUDA extension.\n", + "a grouped GEMM kernel and export it as a PyTorch CUDA extension. Note that GEMM and Conv2d can also be exported as PyTorch CUDA extensions. \n", "\n", "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/tree/master/examples/00_basic_gemm.ipynb)\n", "\n", @@ -230,14 +230,6 @@ "print('Non-Grouped: {:.3f} us'.format(nongrouped * 1e6/num_profile))\n", "print('Speedup: {:.3f}'.format(nongrouped / grouped))" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f22fc696", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/examples/python/03_basic_conv2d.ipynb b/examples/python/03_basic_conv2d.ipynb new file mode 100644 index 0000000000..962add39af --- /dev/null +++ b/examples/python/03_basic_conv2d.ipynb @@ -0,0 +1,423 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Basic example of using the CUTLASS Python interface for Conv2d\n", + "\n", + "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run Conv2d. \n", + "\n", + "We first import various packages needed for the example and construct the input and output tensors that will be used in our example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import random\n", + "\n", + "import cutlass\n", + "\n", + "# This controls whether the C++ GEMM declaration will be printed at each step. \n", + "# Set to `false` to omit this information.\n", + "print_module = True\n", + "\n", + "# Input tensor: [N, H, W, C] under the channel-last layout\n", + "N, H, W, C = [32, 28, 28, 64]\n", + "\n", + "# Weight tensor: [K, R, S, C] under the channel-last layout\n", + "K, R, S = [128, 3, 3]\n", + "\n", + "# Stride, and padding\n", + "stride = (2, 2)\n", + "padding = (1, 1)\n", + "dilation = (1, 1)\n", + "\n", + "# Compute the output size [N, P, Q, K]\n", + "N, P, Q, K = cutlass.Conv2d.output_size((N, H, W, C), (K, R, S, C), padding, stride, dilation)\n", + "\n", + "dtype = torch.float16\n", + "type_A = torch.float16\n", + "type_B = torch.float16\n", + "type_C = torch.float16\n", + "type_D = torch.float16\n", + "\n", + "torch.manual_seed(1234)\n", + "\n", + "input = torch.ceil(\n", + " torch.empty(size=(N, C, H, W), dtype=type_A, device=\"cuda\").uniform_(-4.5, 3.5)\n", + ").to(memory_format=torch.channels_last)\n", + "weight = torch.ceil(\n", + " torch.empty(size=(K, C, R, S), dtype=type_B, device=\"cuda\").uniform_(-4.5, 3.5)\n", + ").to(memory_format=torch.channels_last)\n", + "tensor_C = torch.ceil(\n", + " torch.empty(size=(N, K, P, Q), dtype=type_B, device=\"cuda\").uniform_(-4.5, 3.5)\n", + ").to(memory_format=torch.channels_last)\n", + "output = torch.zeros_like(tensor_C)\n", + "\n", + "alpha = 1.0\n", + "beta = 0.0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Declaring and running a Conv2d Fprop\n", + "\n", + "We first show you how to run a Conv2d in the forward propagation. To get started, one only needs to provide the tensors declared above to the `cutlass.op.Conv2dFprop` call. This sets up a default Conv2d fprop operation for the given device on which you are running. \n", + "\n", + "Assuming that we are runing on SM80, the default is a Conv2d that leverages FP16 Tensor Core operations.\n", + "\n", + "Calling `plan.run()` will generate the CUTLASS C++ kernel in question, compile it, and run it on the tensors we previously passed in. By setting `print_module` to `true`, the C++ code that is emitted is printed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Specifying `element_accumulator` is not required if it is the same as `element`\n", + "plan = cutlass.Conv2dFprop(element=dtype, element_accumulator=torch.float32)\n", + "plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There are many other ways to construct a plan from `cutlass.op.Conv2dFprop` (e.g., by specifying the types of each operand, by providing representative tensors as input). For more details on these, see the documentation in the `cutlass.op.Conv2dFprop` constructor.\n", + "\n", + "We then compare the output to running the Conv2d using PyTorch. PyTorch use NCHW layout by default, so permutations are required." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_torch = alpha * torch.ops.aten.conv2d(\n", + " input, weight, stride=stride, padding=padding, dilation=dilation\n", + ") + beta * tensor_C\n", + "\n", + "assert torch.equal(output_torch, output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that one could use the same kernel just declared for tensors provided by other frameworks beyond PyTorch, such as NumPy." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Declaring and running Conv2d Dgrad and Wgrad\n", + "\n", + "The Python interface also supports declaring and running backward kernels of Conv2d. To begin with, we construct the tensors for the gradient of input, output, and weight." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "grad_output = torch.ceil(\n", + " torch.empty(size=(N, K, P, Q), dtype=type_A, device=\"cuda\").uniform_(-4.5, 3.5)\n", + ").to(memory_format=torch.channels_last)\n", + "grad_input = torch.zeros_like(input)\n", + "grad_weight = torch.zeros_like(weight)\n", + "\n", + "tensor_C_dgrad = torch.ceil(\n", + " torch.empty(size=(N, C, H, W), dtype=type_A, device=\"cuda\").uniform_(-4.5, 3.5)\n", + ").to(memory_format=torch.channels_last)\n", + "tensor_C_wgrad = torch.ceil(\n", + " torch.empty(size=(K, C, R, S), dtype=type_B, device=\"cuda\").uniform_(-4.5, 3.5)\n", + ").to(memory_format=torch.channels_last)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The script below gives a simple example of computing a data gradient via the CUTLASS Python interface and via PyTorch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plan_dgrad = cutlass.Conv2dDgrad(element=dtype, element_accumulator=torch.float32)\n", + "plan_dgrad.run(grad_output, weight, tensor_C_dgrad, grad_input, stride, padding, dilation, alpha, beta, print_module=print_module)\n", + "\n", + "grad_input_torch = alpha * torch.nn.grad.conv2d_input(\n", + " (N, C, H, W),\n", + " weight, grad_output,\n", + " stride=stride, padding=padding\n", + ") + beta * tensor_C_dgrad\n", + "\n", + "assert torch.equal(grad_input_torch, grad_input)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The script below gives a simple example of computing a weight gradient via the CUTLASS Python interface and via PyTorch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plan_wgrad = cutlass.Conv2dWgrad(element=dtype, element_accumulator=torch.float32)\n", + "plan_wgrad.run(grad_output, input, tensor_C_wgrad, grad_weight, stride, padding, dilation, alpha, beta, print_module=print_module)\n", + "\n", + "grad_weight_torch = alpha * torch.nn.grad.conv2d_weight(\n", + " input, (K, C, R, S), grad_output,\n", + " stride=stride, padding=padding\n", + ") + beta * tensor_C_wgrad\n", + "\n", + "assert torch.equal(grad_weight_torch, grad_weight)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Running non-default Conv2ds\n", + "\n", + "The previous examples showed how it is simple to get starting running a default Conv2d kernel in CUTLASS. But, what do you do if you want a bit more control over the parameters to the Conv2d? CUTLASS Python interface exposes mutable parameters that can be set after the `plan` initialization. We summarize these in the table below.\n", + "\n", + "|Parameter|Description|\n", + "| -- | -- |\n", + "|`tile_description`|The threadblock tile size, warp count, software pipeline stages, and instruction shape|\n", + "|`iterator_algorithm`|The iterator algorithm used to access the source operands|\n", + "|`swizzling_stride`|The stride of the threadblock swizzling functor|\n", + "|`split-K`|Partitions the reduction dimension to different threadblocks|" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tile Description\n", + "\n", + "The `tile_description` defines the tiling size of each threadblock, the warp count along each dimension of the tile, the software pipeline stages, and the instruction size. Under the hood, CUTLASS enumerates the different Conv2d configuration parameters for this kernel from the CUTLASS profiler. The code below shows how one can access the tile descriptions for the kernel (e.g., threadblock and warp shape)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plan.opclass = \"tensor_op\"\n", + "tiles = plan.tile_descriptions()\n", + "print(f'{len(tiles)} tile descriptions returned')\n", + "num_print = 10\n", + "print(f'First {num_print} tile descriptions are:')\n", + "for td in tiles[:num_print]:\n", + " print(td)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we'll pick one of these configurations at random and compile and run it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "random.seed(42)\n", + "idx = random.randint(0, len(tiles)-1)\n", + "td = tiles[idx]\n", + "print(f'Tile description {idx} is: {td}')\n", + "plan.tile_description = td\n", + "plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)\n", + "assert torch.equal(output_torch, output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Besides tile descriptions enumerated by CUTLASS, the users can also explicitly set the `threadblockshape`, `warp_shape`, `stages`, `instruction_shape`, and `cluster_shape`. If the configuration is invalid, an exception will be raised at `plan.run()` and the detailed compilation error will be stored in `./cutlass_python_compilation_error.txt` for debugging." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if plan.cc == 70:\n", + " plan.tile_description = {\n", + " \"threadblock_shape\": [64, 256, 32],\n", + " \"warp_count\": [1, 4, 1],\n", + " \"stages\": 2,\n", + " \"instruction_shape\": [8, 8, 4], # optional,\n", + " \"cluster_shape\": [1, 1, 1] # optional, only [1, 1, 1] is supported currently\n", + " }\n", + "elif plan.cc == 75:\n", + " plan.tile_description = {\n", + " \"threadblock_shape\": [128, 64, 32],\n", + " \"warp_count\": [2, 1, 1],\n", + " \"stages\": 2,\n", + " \"instruction_shape\": [16, 8, 8], # optional,\n", + " \"cluster_shape\": [1, 1, 1] # optional, only [1, 1, 1] is supported currently\n", + " }\n", + "elif plan.cc == 80:\n", + " plan.tile_description = {\n", + " \"threadblock_shape\": [128, 128, 64],\n", + " \"warp_count\": [2, 2, 1],\n", + " \"stages\": 4,\n", + " \"instruction_shape\": [16, 8, 16], # optional,\n", + " \"cluster_shape\": [1, 1, 1] # optional, only [1, 1, 1] is supported currently\n", + " }\n", + "elif plan.cc == 86:\n", + " plan.tile_description = {\n", + " \"threadblock_shape\": [128, 64, 64],\n", + " \"warp_count\": [2, 2, 1],\n", + " \"stages\": 3,\n", + " \"instruction_shape\": [16, 8, 16],\n", + " \"cluster_shape\": [1, 1, 1]\n", + " }\n", + "\n", + "plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)\n", + "assert torch.equal(output_torch, output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Iterator Algorithm\n", + "\n", + "The iterator algorithm describes how sources are loaded from memory. There are some iterator algorithms optimized for specific alignments and input/output channels that have better performance. The table below illustrates the available iterator algorithms.\n", + "\n", + "|Conv Kind | Iterator Algorithm | Description |\n", + "| -- | -- | -- |\n", + "|Fprop | \"analytic\" | Functionally correct in all cases but lower performance |\n", + "| | \"optimized\" | Optimized for and requires `R <= 32`, `S<= 32`, and `C % alignment_input == 0`|\n", + "| | \"few_channels\" | optimized for small `C` and requires `C % alignment_input == 0`|\n", + "| | \"fixed_channels\" | optimized for small `C` and requires `C == alignment_input` |\n", + "|Dgrad | \"analytic\" | Functionally correct in all cases but lower performance |\n", + "| | \"optimized\" | Optimzed for and require `R <= 32`, `S<= 32`, `K % alignment_grad_output == 0`, and `C % alignment_weight == 0`|\n", + "|Wgrad | \"analytic\" | Functionally correct in all cases but lower performance |\n", + "| | \"optimized\" | Optimized for and require `K % alignment_grad_output == 0`, and `C % alignment_input == 0`|\n", + "\n", + "By default, the Python interface will automatically propose a suitable iterator algorithm based on the input tensors in `plan.run()`. However, the user can also specify the desired iterator algorithm as follows" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plan.iterator_algorithm = \"analytic\"\n", + "plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)\n", + "assert torch.equal(output_torch, output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If the iterator algorithm is invalid for the problem size in `plan.run()`, an exception will be raised." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Swizzling Stride\n", + "The swizzling changes how the tile are mapped to threadblocks to improve the L2 Locality. Given a swizzling stride `N`, the threadblock `(tb_x, tb_y)` computes tile `(tb_x / N, tb_y * N + (tb_x % N))`. Currently, stride values of `1`, `2`, `4`, and `8` are supported for `fprop`, `wgrad`, and `1`, and `4` for `dgrad`. The swizzling stride can be set with:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plan.swizzling_stride = 4\n", + "plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)\n", + "assert torch.equal(output_torch, output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Split-K\n", + "Split-K is usually applied when the Conv2d has small spatial dimensions and large reduction dimension to ensure good utilization. It further partitions the reduction dimension to different threadblocks. The CUTLASS Python interface supports two types of split-K strategies: `Parallel`, and `Serial`. \n", + "* `Parallel`: the partial results from different threadblocks are stored in a temporary buffer in the global memory. When the Conv2d is done, a separate reduction kernel is created and launched to reduce the partial results.\n", + "* `Serial`: A semaphore is used to coordinate the order of different threadblocks adding their partial results to a given output tile. A separate kernel does not need to be launched for prforming the reduction.\n", + "\n", + "While all `fprop`, `dgrad`, and `wgrad` support split-K, here we use `wgrad` as an example. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Parallel Split-K with 5 slices\n", + "grad_weight_parallel = torch.zeros_like(grad_weight)\n", + "plan_wgrad.run(\n", + " grad_output, input, tensor_C_wgrad, grad_weight_parallel, \n", + " stride, padding, dilation, alpha, beta, print_module=print_module, split_k=(\"parallel\", 5))\n", + "assert torch.equal(grad_weight_torch, grad_weight_parallel)\n", + "\n", + "# Serial Split-K with 3 slices\n", + "grad_weight_serial = torch.zeros_like(grad_weight)\n", + "plan_wgrad.run(\n", + " grad_output, input, tensor_C_wgrad, grad_weight_serial, \n", + " stride, padding, dilation, alpha, beta, print_module=print_module, split_k=(\"serial\", 3))\n", + "assert torch.equal(grad_weight_torch, grad_weight_serial)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/python/README.md b/examples/python/README.md index fab167c9df..2ed80e1939 100644 --- a/examples/python/README.md +++ b/examples/python/README.md @@ -12,3 +12,7 @@ Shows how to declare, compile, and run a grouped GEMM operation via the Python interface, along with how the emitted kernel can be easily exported to a PyTorch CUDA extension. + +* [03_basic_conv2d](/examples/python/03_basic_conv2d.ipynb) + + Shows how to declare, configure, compile, and run a CUTLASS Conv2d using the Python interface diff --git a/include/cute/algorithm/tensor_algorithms.hpp b/include/cute/algorithm/tensor_algorithms.hpp index 5fac8f92a6..294374b8b5 100644 --- a/include/cute/algorithm/tensor_algorithms.hpp +++ b/include/cute/algorithm/tensor_algorithms.hpp @@ -112,12 +112,47 @@ transform(Tensor& tensor_in, Tensor& ten } // Accept mutable temporaries -template +template CUTE_HOST_DEVICE constexpr void transform(Tensor&& tensor_in, Tensor&& tensor_out, UnaryOp&& op) { - return transform(tensor_in, tensor_out, std::forward(op)); + return transform(tensor_in, tensor_out, op); +} + +// Similar to std::transform with a binary operation +// Takes two tensors as input and one tensor as output. +// Applies the binary_op to tensor_in1 and and tensor_in2 and +// assigns it to tensor_out +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor& tensor_in1, + Tensor& tensor_in2, + Tensor& tensor_out, + BinaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor_in1); ++i) { + tensor_out(i) = static_cast(op)(tensor_in1(i), tensor_in2(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor&& tensor_in1, + Tensor&& tensor_in2, + Tensor&& tensor_out, + BinaryOp&& op) +{ + return transform(tensor_in1, tensor_in2, tensor_out, op); } } // end namespace cute diff --git a/include/cute/algorithm/tuple_algorithms.hpp b/include/cute/algorithm/tuple_algorithms.hpp index ed338ccde2..d9ae200338 100644 --- a/include/cute/algorithm/tuple_algorithms.hpp +++ b/include/cute/algorithm/tuple_algorithms.hpp @@ -38,11 +38,38 @@ #include #include -/** Common algorithms on (hierarchical) tuples */ -/** Style choice: - * Forward params [using static_cast(.)] for const/non-const/ref/non-ref args - * but don't bother forwarding functions as ref-qualified member fns are extremely rare - */ +/// @file tuple_algorithms.hpp +/// @brief Common algorithms on (hierarchical) tuples +/// +/// Code guidelines and style preferences: +/// +/// For perfect forwarding, don't use std::forward, because it may not +/// be defined in device code when compiling with NVRTC. Instead, use +/// `static_cast(parameter_name)`. +/// +/// CuTe generally does not bother forwarding functions, as +/// reference-qualified member functions are rare in this code base. +/// +/// Throughout CUTLASS, cute::make_tuple always needs to be called +/// namespace-qualified, EVEN If inside the cute namespace and/or in +/// scope of a "using namespace cute" declaration. Otherwise, the +/// compiler may select std::make_tuple instead of cute::make_tuple, +/// due to argument-dependent lookup. Two problems may result from +/// that. +/// +/// 1. Functions have an unexpected return type (std::tuple instead of +/// cute::tuple), so functions that take cute::tuple parameters +/// fail to compile (generally inside functions that have template +/// parameters expected to be cute::tuple). +/// +/// 2. std::tuple does not have the required __host__ __device__ +/// markings, so the CUDA compiler complains if you use it in +/// device code. +/// +/// cute::make_tuple will occur more often than std::make_tuple would +/// in modern C++ code, because cute::tuple's design deprioritizes +/// correct operation of CTAD (constructor template argument +/// deduction) in favor of implementation simplicity. namespace cute { @@ -142,7 +169,13 @@ CUTE_HOST_DEVICE constexpr void for_each(T&& t, F&& f) { - detail::apply(t, [&](auto&&... a) { (f(static_cast(a)), ...); }, tuple_seq{}); + if constexpr (is_tuple>::value) { + return detail::apply(t, [&](auto&&... a) { (f(static_cast(a)), ...); }, tuple_seq{}); + } else { + return f(static_cast(t)); + } + + CUTE_GCC_UNREACHABLE; } template @@ -159,6 +192,36 @@ for_each_leaf(T&& t, F&& f) CUTE_GCC_UNREACHABLE; } +// +// For Sequence +// (s, t, f) => (f(t[s_0]),f(t[s_1]),...,f(t[s_n])) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +void +for_sequence(seq const&, F&& f) { + (f(Int{}), ...); +} + +}; // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +void +for_sequence(seq const& s, T&& t, F&& f) { + detail::for_sequence(s, [&](auto&& i){ f(get::value>(static_cast(t))); }); +} + +template +CUTE_HOST_DEVICE constexpr +void +for_sequence(T&& t, F&& f) { + for_sequence(make_seq{}, static_cast(t), static_cast(f)); +} + // // Transform // (t, f) => (f(t_0),f(t_1),...,f(t_n)) @@ -169,7 +232,13 @@ CUTE_HOST_DEVICE constexpr auto transform(T const& t, F&& f) { - return detail::tapply(t, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); + if constexpr (is_tuple::value) { + return detail::tapply(t, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; } template @@ -177,8 +246,14 @@ CUTE_HOST_DEVICE constexpr auto transform(T0 const& t0, T1 const& t1, F&& f) { - static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); - return detail::tapply(t0, t1, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + return detail::tapply(t0, t1, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); + } else { + return f(t0, t1); + } + + CUTE_GCC_UNREACHABLE; } template @@ -186,9 +261,15 @@ CUTE_HOST_DEVICE constexpr auto transform(T0 const& t0, T1 const& t1, T2 const& t2, F&& f) { - static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); - static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); - return detail::tapply(t0, t1, t2, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + return detail::tapply(t0, t1, t2, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); + } else { + return f(t0, t1, t2); + } + + CUTE_GCC_UNREACHABLE; } template @@ -399,7 +480,7 @@ fold_first(T&& t, F&& f) } // -// front, back, take, unwrap +// front, back, take, select, unwrap // // Get the first non-tuple element in a hierarchical tuple @@ -425,7 +506,16 @@ back(T&& t) { if constexpr (is_tuple>::value) { constexpr int N = tuple_size>::value; - return back(get(static_cast(t))); + + // MSVC needs a bit of extra help here deducing return types. + // We help it by peeling off the nonrecursive case a level "early." + if constexpr (! is_tuple(static_cast(t)))>>::value) { + return get(static_cast(t)); + } + else { + return back(get(static_cast(t))); + } + } else { return static_cast(t); } @@ -442,6 +532,47 @@ take(T const& t) return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range{}); } +// +// Select tuple elements with given indices. +// + +template +CUTE_HOST_DEVICE constexpr +auto +select(T const & t) +{ + return cute::make_tuple(get(t)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +select(T const & t, Indices const & indices) +{ + if constexpr (is_tuple::value) { + return cute::transform(indices, [&t](auto i) { return select(t, i); }); + } + else { + static_assert(is_static::value, "Order must be static"); + return get(t); + } +} + +// Wrap non-tuples into rank-1 tuples or forward +template +CUTE_HOST_DEVICE constexpr +auto +wrap(T const& t) +{ + if constexpr (is_tuple::value) { + return t; + } else { + return cute::make_tuple(t); + } + + CUTE_GCC_UNREACHABLE; +} + // Unwrap rank-1 tuples until we're left with a rank>1 tuple or a non-tuple template CUTE_HOST_DEVICE constexpr @@ -576,7 +707,11 @@ CUTE_HOST_DEVICE constexpr auto repeat(X const& x) { - return detail::construct(0, x, seq<>{}, make_seq{}, seq<>{}); + if constexpr (N == 1) { + return x; + } else { + return detail::construct(0, x, seq<>{}, make_seq{}, seq<>{}); + } } // @@ -605,7 +740,23 @@ CUTE_HOST_DEVICE constexpr auto group(T const& t) { - return detail::construct(t, take(t), make_seq{}, seq<0>{}, make_range::value>{}); + if constexpr (not is_tuple::value) { + if constexpr (E == -1) { + return group(t); + } else { + return detail::construct(t, take(t), make_seq{}, make_seq<(B < E)>{}, make_range{}); + } + } else + if constexpr (E == -1) { + return group::value>(t); + } else + if constexpr (B <= E) { + return detail::construct(t, take(t), make_seq{}, make_seq<(B < E)>{}, make_range::value>{}); + } else { + static_assert(B <= E); + } + + CUTE_GCC_UNREACHABLE; } // @@ -685,6 +836,48 @@ prepend(T const& a, X const& x) CUTE_GCC_UNREACHABLE; } +// +// Unflatten a flat tuple into a hierarchical one +// unflatten(x, flatten(x)) == x +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile) +{ + if constexpr (is_tuple::value) { + return fold(target_profile, cute::make_tuple(cute::make_tuple(), flat_tuple), [](auto const& v, auto const& t) { + auto [result, remaining_tuple] = v; + auto [sub_result, sub_tuple] = unflatten_impl(remaining_tuple, t); + return cute::make_tuple(append(result, sub_result), sub_tuple); + }); + } else { + return cute::make_tuple(get<0>(flat_tuple), take<1, decltype(rank(flat_tuple))::value>(flat_tuple)); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// @pre flatten(@a flat_tuple) == @a flat_tuple +// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple) +// @post congruent(@a result, @a target_profile) +// @post flatten(@a result) == @a flat_tuple +template +CUTE_HOST_DEVICE constexpr +auto +unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile) +{ + auto [unflatten_tuple, flat_remainder] = detail::unflatten_impl(flat_tuple, target_profile); + CUTE_STATIC_ASSERT_V(rank(flat_remainder) == Int<0>{}); + return unflatten_tuple; +} + + // // Inclusive scan (prefix sum) // @@ -872,4 +1065,18 @@ zip2_by(T const& t, TG const& guide) CUTE_GCC_UNREACHABLE; } +/// @return A tuple of the elements of @c t in reverse order. +template +CUTE_HOST_DEVICE constexpr auto +reverse(T const& t) { + if constexpr (is_tuple::value) { + return detail::apply(t, [] (auto const&... a) { + return cute::make_tuple(a...); + }, tuple_rseq{}); + } + else { + return t; + } +} + } // end namespace cute diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp index 489998b589..aaef8b4161 100644 --- a/include/cute/arch/copy_sm90_desc.hpp +++ b/include/cute/arch/copy_sm90_desc.hpp @@ -68,7 +68,7 @@ initialize_barrier(uint64_t& smem_barrier, // 64 bits user-mange #endif } -// Set the number of bytes transfered per transaction +// Set the number of bytes transfered per transaction and perform an arrive operation as well CUTE_HOST_DEVICE void set_barrier_transaction_bytes(uint64_t& smem_barrier, // 64 bits user-manged barrier in smem @@ -134,27 +134,32 @@ enum class SmemSwizzleBits : uint8_t { B128 = 3, }; -#if !defined(__CUDACC_RTC__) #if (__CUDACC_VER_MAJOR__ >= 12) +#if !defined(__CUDACC_RTC__) +/// @return The TMA descriptor datatype enum corresponding to T. template -inline CUtensorMapDataType to_CUtensorMapDataType() { - if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else - if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else - if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else - if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else - if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_UINT64; } else - if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_INT32; } else - if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_INT64; } else - if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; } else - if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; } else - if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; } else - if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else - if constexpr (is_same::value) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else +inline CUtensorMapDataType +to_CUtensorMapDataType() { + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT64; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_INT32; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_INT64; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else { static_assert(sizeof(T) < 0, "Unknown TMA Format!"); } } -inline CUtensorMapSwizzle to_CUtensorMapSwizzle(SmemSwizzleBits const& t) { +inline CUtensorMapSwizzle +to_CUtensorMapSwizzle(SmemSwizzleBits const& t) { switch (t) { default: assert(false && "Unknown SmemSwizzleBits!"); case SmemSwizzleBits::DISABLE: return CU_TENSOR_MAP_SWIZZLE_NONE; @@ -163,15 +168,16 @@ inline CUtensorMapSwizzle to_CUtensorMapSwizzle(SmemSwizzleBits const& t) { case SmemSwizzleBits::B128: return CU_TENSOR_MAP_SWIZZLE_128B; } } +#endif // !defined(__CUDACC_RTC__) #endif // (__CUDACC_VER_MAJOR__ >= 12) -#endif // !defined(__CUDACC_RTC__) + } // end namespace TMA #if (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) -using TmaDescriptor = CUtensorMap; + using TmaDescriptor = CUtensorMap; #else -using TmaDescriptor = struct { char bytes[128]; }; + using TmaDescriptor = struct { char bytes[128]; }; #endif //////////////////////////////////////////////////////////////////////////////////////////////////// /// Initiates a TensorMap Prefetch diff --git a/include/cute/arch/copy_sm90_tma.hpp b/include/cute/arch/copy_sm90_tma.hpp index 412754c7a2..46cace385d 100644 --- a/include/cute/arch/copy_sm90_tma.hpp +++ b/include/cute/arch/copy_sm90_tma.hpp @@ -34,7 +34,6 @@ #include #include - namespace cute { @@ -503,8 +502,7 @@ struct SM90_TMA_LOAD_MULTICAST struct SM90_TMA_LOAD_IM2COL_MULTICAST_3D { CUTE_HOST_DEVICE static void - copy(void const* const desc_ptr, uint64_t& smem_mbar, - uint16_t const& multicast_mask, + copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, void const* const smem_ptr, int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, uint16_t const& offset_w) @@ -532,12 +530,10 @@ struct SM90_TMA_LOAD_IM2COL_MULTICAST_3D struct SM90_TMA_LOAD_IM2COL_MULTICAST_4D { CUTE_HOST_DEVICE static void - copy(void const* const desc_ptr, uint64_t& smem_mbar, - uint16_t const& multicast_mask, + copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, void const* const smem_ptr, int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, - uint16_t const& offset_w, - uint16_t const& offset_h) + uint16_t const& offset_w, uint16_t const& offset_h) { #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); @@ -545,7 +541,7 @@ struct SM90_TMA_LOAD_IM2COL_MULTICAST_4D uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); // Copy from global to shared::cluster. asm volatile ( - "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), @@ -562,13 +558,10 @@ struct SM90_TMA_LOAD_IM2COL_MULTICAST_4D struct SM90_TMA_LOAD_IM2COL_MULTICAST_5D { CUTE_HOST_DEVICE static void - copy(void const* const desc_ptr, uint64_t& smem_mbar, - uint16_t const& multicast_mask, + copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, void const* const smem_ptr, int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, - uint16_t const& offset_w, - uint16_t const& offset_h, - uint16_t const& offset_d) + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) { #if defined(CUTE_ARCH_TMA_SM90_ENABLED) uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); @@ -576,7 +569,7 @@ struct SM90_TMA_LOAD_IM2COL_MULTICAST_5D uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); // Copy from global to shared::cluster. asm volatile ( - "cp.async.bulk.tensor.5d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + "cp.async.bulk.tensor.5d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10}, %11;" : : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), @@ -593,45 +586,39 @@ struct SM90_TMA_LOAD_IM2COL_MULTICAST_5D struct SM90_TMA_LOAD_IM2COL_MULTICAST { CUTE_HOST_DEVICE static void - copy(void const* const desc_ptr, uint64_t& smem_mbar, - uint16_t const& multicast_mask, + copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, void const* const smem_ptr, int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, uint16_t const& offset_w) { - return SM90_TMA_LOAD_IM2COL_MULTICAST_3D::copy(desc_ptr, smem_mbar, - multicast_mask, smem_ptr, - coord_c, coord_w, coord_n, - offset_w); + return SM90_TMA_LOAD_IM2COL_MULTICAST_3D::copy(desc_ptr, smem_mbar, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_n, + offset_w); } CUTE_HOST_DEVICE static void - copy(void const* const desc_ptr, uint64_t& smem_mbar, - uint16_t const& multicast_mask, + copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, void const* const smem_ptr, int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, - uint16_t const& offset_w, - uint16_t const& offset_h) + uint16_t const& offset_w, uint16_t const& offset_h) { - return SM90_TMA_LOAD_IM2COL_MULTICAST_4D::copy(desc_ptr, smem_mbar, - multicast_mask, smem_ptr, - coord_c, coord_w, coord_h, coord_n, - offset_w, offset_h); + return SM90_TMA_LOAD_IM2COL_MULTICAST_4D::copy(desc_ptr, smem_mbar, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); } CUTE_HOST_DEVICE static void - copy(void const* const desc_ptr, uint64_t& smem_mbar, - uint16_t const& multicast_mask, + copy(void const* const desc_ptr, uint64_t& smem_mbar, uint16_t multicast_mask, void const* const smem_ptr, int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, - uint16_t const& offset_w, - uint16_t const& offset_h, - uint16_t const& offset_d) + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) { - return SM90_TMA_LOAD_IM2COL_MULTICAST_5D::copy(desc_ptr, smem_mbar, - multicast_mask, smem_ptr, + return SM90_TMA_LOAD_IM2COL_MULTICAST_5D::copy(desc_ptr, smem_mbar, multicast_mask, + smem_ptr, coord_c, coord_w, coord_h, coord_d, coord_n, - offset_w, offset_h, offset_d); + offset_w, offset_h, offset_d); } }; diff --git a/include/cute/arch/mma_sm90.hpp b/include/cute/arch/mma_sm90.hpp index 42778c808f..25a98e6cb0 100644 --- a/include/cute/arch/mma_sm90.hpp +++ b/include/cute/arch/mma_sm90.hpp @@ -502,6 +502,151 @@ ss_op_selector() static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); } } + + // FP8 + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F32E4M3E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP8 + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F32E4M3E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP8 + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F32E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F32E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F32E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F32E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F32E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F32E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F32E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F32E5M2E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP8 + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F32E5M2E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + else { static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); } @@ -809,6 +954,150 @@ rs_op_selector() } } + // FP8 + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F32E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP8 + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F32E4M3E5M2_RS_TN{}; + } + else { + static_aRSert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP8 + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F32E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // FP8 + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90_64x256x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 192 == 0) { + return SM90_64x192x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 128 == 0) { + return SM90_64x128x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 96 == 0) { + return SM90_64x96x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 64 == 0) { + return SM90_64x64x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 32 == 0) { + return SM90_64x32x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 16 == 0) { + return SM90_64x16x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90_64x8x32_F32E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + else { static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); } diff --git a/include/cute/arch/mma_sm90_desc.hpp b/include/cute/arch/mma_sm90_desc.hpp index dd4e1fb59d..69469d5697 100644 --- a/include/cute/arch/mma_sm90_desc.hpp +++ b/include/cute/arch/mma_sm90_desc.hpp @@ -89,6 +89,28 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, LayoutType const& t) { union GmmaDescriptor { + + CUTE_HOST_DEVICE constexpr + GmmaDescriptor() noexcept : desc_(0) {} + CUTE_HOST_DEVICE constexpr + GmmaDescriptor(uint64_t desc) noexcept : desc_(desc) {} + CUTE_HOST_DEVICE constexpr + GmmaDescriptor(GmmaDescriptor const& t) noexcept : desc_(t.desc_) {} + CUTE_HOST_DEVICE constexpr + GmmaDescriptor(GmmaDescriptor && t) noexcept : desc_(t.desc_) {} + + CUTE_HOST_DEVICE constexpr + GmmaDescriptor& operator=(GmmaDescriptor const& t) noexcept { + desc_ = t.desc_; + return *this; + } + + CUTE_HOST_DEVICE constexpr + GmmaDescriptor& operator=(GmmaDescriptor && t) noexcept { + desc_ = t.desc_; + return *this; + } + uint64_t desc_; uint32_t reg32_[2]; uint16_t reg16_[4]; @@ -112,7 +134,7 @@ union GmmaDescriptor // layout type, bit [62,64) // SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1 uint8_t : 6, layout_type_ : 2; // 6 bits unused, 2 bits [6,8) - }; + } bitfield; // Decay to a uint64_t CUTE_HOST_DEVICE constexpr @@ -123,11 +145,11 @@ union GmmaDescriptor { #if !defined(__CUDACC_RTC__) printf("GmmaDescriptor: 0x%016" PRIx64 "\n", t.desc_); - printf(" start_addr : 0x%04x\n", t.start_address_); - printf(" leading_off: 0x%04x (%d)\n", t.leading_byte_offset_, t.leading_byte_offset_); - printf(" stride_off : 0x%04x (%d)\n", t.stride_byte_offset_, t.stride_byte_offset_); - printf(" base_offset: 0x%01x\n", t.base_offset_); - printf(" layout_type: 0x%01x (%s)\n", t.layout_type_, to_string(static_cast(t.layout_type_))); + printf(" start_addr : 0x%04x\n", t.bitfield.start_address_); + printf(" leading_off: 0x%04x (%d)\n", t.bitfield.leading_byte_offset_, t.bitfield.leading_byte_offset_); + printf(" stride_off : 0x%04x (%d)\n", t.bitfield.stride_byte_offset_, t.bitfield.stride_byte_offset_); + printf(" base_offset: 0x%01x\n", t.bitfield.base_offset_); + printf(" layout_type: 0x%01x (%s)\n", t.bitfield.layout_type_, to_string(static_cast(t.bitfield.layout_type_))); #endif } }; diff --git a/include/cute/arch/mma_sm90_gmma.hpp b/include/cute/arch/mma_sm90_gmma.hpp index db4083ee82..dc0abb57a2 100644 --- a/include/cute/arch/mma_sm90_gmma.hpp +++ b/include/cute/arch/mma_sm90_gmma.hpp @@ -32,7 +32,6 @@ #include #include - // Config #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) # define CUTE_ARCH_MMA_SM90A_ENABLED @@ -59,7 +58,7 @@ CUTE_HOST_DEVICE void warpgroup_wait() { - static_assert(N >= 0 && N <= 7, "_warpgroup.wait {N}; must be in range [0, 7]"); + static_assert(N >= 0 && N <= 7, "WGMMA wait: N must be in range [0, 7]"); #if defined(CUTE_ARCH_MMA_SM90A_ENABLED) asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); #else @@ -12781,4 +12780,7860 @@ struct SM90_64x256x32_S32U8U8_RS_TN_SATURATE //////////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA 64x8x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e5m2 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x8x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x8x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x16x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x16x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x32x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x32x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x64x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x64x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x96x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x96x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x128x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x128x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x192x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x192x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct SM90_64x256x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_RUNTIME_ASSERT("Attempting to use SM90_64x256x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cute diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index cfffbcab55..098aa8caf3 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -479,6 +479,12 @@ make_tiled_copy_C_atom(Copy_Atom const& copy_atom, return make_tiled_copy_impl(copy_atom, layout_tv, tiler); } +/** Produce a TiledCopy from logical thread and values layouts. + * The thread and value layouts map coordinates to thr_idx and val_idx. + * The product of these layouts is taken to produce the TV layout and the Tiler. + * Useful when threads and values need very specific mappings onto coordinates + * in the target tensors. + */ template > @@ -486,7 +492,7 @@ CUTE_HOST_DEVICE auto make_tiled_copy(Copy_Atom const& copy_atom, ThrLayout const& thr_layout = {}, // (m,n) -> thr_idx - ValLayout const& val_layout = {}) + ValLayout const& val_layout = {}) // (m,n) -> val_idx { constexpr int R = cute::max(rank_v, rank_v); @@ -496,14 +502,82 @@ make_tiled_copy(Copy_Atom const& copy_atom, // Take the raked_products to compute the Layout_MN auto layout_mn = raked_product(thr_layout_mn, val_layout_mn); auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout))); - // print("thr_layout: "); print(thr_layout_mn); print("\n"); - // print("val_layout: "); print(val_layout_mn); print("\n"); - // print("layout_mn : "); print(layout_mn); print("\n"); - // print("layout_tv : "); print(layout_tv); print("\n"); + // print("thr_layout: "); print(thr_layout_mn); print("\n"); + // print("val_layout: "); print(val_layout_mn); print("\n"); + // print("layout_mn : "); print(layout_mn); print("\n"); + // print("layout_tv : "); print(layout_tv); print("\n"); return make_tiled_copy_impl(copy_atom, layout_tv, product_each(shape(layout_mn))); } +/** Produce a TiledCopy from thread and value offset maps. + * The TV Layout maps threads and values to the codomain of the data_layout. + * It is verified that the intended codomain is valid within data_layout. + * Useful when threads and values don't care about owning specific coordinates, but + * care more about the vector-width and offsets between them. + */ +template +CUTE_HOST_DEVICE constexpr +auto +make_cotiled_copy(Copy_Atom const& copy_atom, + AtomTVLayout const& atom_tv_layout, // atom (thr,val) -> data addr + DataLayout const& data_layout) // coord -> data addr The target layout +{ + static_assert(is_static::value); + static_assert(is_static::value); + + // data addr -> data coord Append 1:0 so off-the-ends get the stride-0 + auto inv_data_layout = make_layout(left_inverse(data_layout), Layout<_1,_0>{}); + + // (tid,vid) -> data_coord + auto layout_tv_data = composition(inv_data_layout, atom_tv_layout); + + // Check validity + CUTE_STATIC_ASSERT_V(coalesce(composition(data_layout, layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)), + "The memory pointed to by AtomTVLayout does not exist in the DataLayout."); + +#if 0 + if (thread0()) { + print("data_layout : "); print(data_layout); print("\n"); + print("atom_tv_layout : "); print(atom_tv_layout); print("\n"); + print("layout_tv_data : "); print(layout_tv_data); print("\n"); + } +#endif + + // + // Tiler -- Find the active elements in the DATA tensor and generate a tiler to extract them + // + + // Convert to the awkward by-mode tiler to preserve the modes of the tiled DATA + auto flat_data_shape = product_each(shape(data_layout)); + auto flat_data_zeros = repeat(Int<0>{}); + + auto tiler = transform(make_seq{}, [&](auto i) { + return filter(composition(make_layout(flat_data_shape, replace(flat_data_zeros, Int<1>{})), layout_tv_data)); + }); + + // + // Layout_TV -- Find the (tid,vid) -> tile coord transformation + // + + // Apply the tiler to a reference and transform the codomain + // tile_coord -> data_coord + auto tile2data = composition(make_layout(flat_data_shape), tiler); + + // (tid,vid) -> tile_coord + auto layout_tv = composition(left_inverse(tile2data), layout_tv_data); + +#if 0 + if (thread0()) { + print("tiler : "); print(tiler); print("\n"); + print("tile2data : "); print(tile2data); print("\n"); + print("layout_tv : "); print(layout_tv); print("\n"); + } +#endif + + return make_tiled_copy_impl(copy_atom, layout_tv, tiler); +} + // Make a TiledCopy out of the copy_atom that matches the Src-Layout of tiled_copy template diff --git a/include/cute/atom/copy_traits.hpp b/include/cute/atom/copy_traits.hpp index cea03c0ff0..9c4821d90d 100644 --- a/include/cute/atom/copy_traits.hpp +++ b/include/cute/atom/copy_traits.hpp @@ -128,4 +128,20 @@ copy_unpack(Copy_Traits const&, rD, make_int_sequence{}); } +// +// Accept mutable temporaries +// + +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor && dst) +{ + copy_unpack(traits, src, dst); +} + } // end namespace cute diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp index 6f3f9d4d8b..9b91f87ef4 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -34,11 +34,7 @@ #include #endif -#include -#include - -#include - +#include #include #include @@ -52,15 +48,15 @@ namespace cute struct SM90_TMA_LOAD_OP : SM90_TMA_LOAD {}; // The executable SM90_TMA_LOAD with tma_desc and tma_mbar -template -struct Copy_Traits +template +struct Copy_Traits { using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + using SrcLayout = Layout>; // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + using DstLayout = Layout>; // Reference map from (thr,val) to bit using RefLayout = SrcLayout; @@ -75,21 +71,8 @@ struct Copy_Traits copy_unpack_(void const* const dst_ptr, Coord const& src_coord, seq) const { -#if 0 - print("THR (%d,%d,%d) BLK (%d,%d,%d)\n", - threadIdx.x, threadIdx.y, threadIdx.z, - blockIdx.x, blockIdx.y, blockIdx.z); - print(" TMA Coord "); print(src_coord); print("\n"); - print(" TMA Shape "); print(make_tuple(uint64_t(tma_desc_.size0_), - uint64_t(tma_desc_.size1_), - uint64_t(tma_desc_.size2_), - uint64_t(tma_desc_.size3_))); print("\n"); -#endif - - SM90_TMA_LOAD::copy(&tma_desc_, - tma_load_mbar_, - dst_ptr, - get(src_coord)...); + SM90_TMA_LOAD::copy(&tma_desc_, tma_load_mbar_, + dst_ptr, get(src_coord)...); } // This is the copy_unpack dispatch for this Copy_Traits @@ -103,24 +86,23 @@ struct Copy_Traits Tensor const& src, Tensor & dst) { - //static_assert(is_gmem::value, "Expected gmem src for SM90_TMA_LOAD"); // TMA spoofed src tensor static_assert(is_smem::value, "Expected smem dst for SM90_TMA_LOAD"); - traits.copy_unpack_(dst.data().get(), src.data().coord_, tuple_seq{}); + traits.copy_unpack_(raw_pointer_cast(dst.data()), src.data().coord_, tuple_seq{}); } }; // The non-executable SM90_TMA_LOAD with tma_desc and no tma_mbar // Use .with(tma_mbar) to construct an executable version -template -struct Copy_Traits +template +struct Copy_Traits { using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + using SrcLayout = Layout>; // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + using DstLayout = Layout>; // Reference map from (thr,val) to bit using RefLayout = SrcLayout; @@ -138,7 +120,7 @@ struct Copy_Traits // Construct an executable SM90_TMA_LOAD with tma_mbar CUTE_HOST_DEVICE constexpr - Copy_Traits + Copy_Traits with(uint64_t& tma_mbar, uint16_t const& multicast_mask = 0) const { // We accept multicast_mask here to keep the API for both atoms consistent // assert(multicast_mask == 0); @@ -152,10 +134,7 @@ struct Copy_Traits auto get_tma_tensor(GShape const& g_shape) const { static_assert(is_congruent::value); - constexpr int tma_rank = decltype(cute::min(rank(flatten(g_stride_)), Int<5>{}))::value; - return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat(Int<0>{}))), - g_shape, - g_stride_); + return make_counting_tensor(make_layout(g_shape, g_stride_)); } // Don't try to execute a copy with SM90_TMA_LOAD before calling .with() @@ -173,15 +152,15 @@ struct Copy_Traits struct SM90_TMA_LOAD_MULTICAST_OP : SM90_TMA_LOAD_MULTICAST {}; -template -struct Copy_Traits +template +struct Copy_Traits { using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + using SrcLayout = Layout>; // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + using DstLayout = Layout>; // Reference map from (thr,val) to bit using RefLayout = SrcLayout; @@ -197,22 +176,8 @@ struct Copy_Traits copy_unpack_(void const* const dst_ptr, Coord const& src_coord, seq) const { -#if 0 - print("THR (%d,%d,%d) BLK (%d,%d,%d)\n", - threadIdx.x, threadIdx.y, threadIdx.z, - blockIdx.x, blockIdx.y, blockIdx.z); - print(" TMA Coord "); print(src_coord); print("\n"); - print(" TMA Shape "); print(make_tuple(uint64_t(tma_desc_.size0_), - uint64_t(tma_desc_.size1_), - uint64_t(tma_desc_.size2_), - uint64_t(tma_desc_.size3_))); print("\n"); -#endif - - SM90_TMA_LOAD_MULTICAST::copy(&tma_desc_, - tma_load_mbar_, - multicast_mask_, - dst_ptr, - get(src_coord)...); + SM90_TMA_LOAD_MULTICAST::copy(&tma_desc_, tma_load_mbar_, multicast_mask_, + dst_ptr, get(src_coord)...); } template Tensor const& src, Tensor & dst) { - //static_assert(is_gmem::value, "Expected gmem src for SM90_TMA_LOAD"); // TMA spoofed src tensor static_assert(is_smem::value, "Expected smem dst for SM90_TMA_LOAD_MULTICAST"); - traits.copy_unpack_(dst.data().get(), src.data().coord_, tuple_seq{}); + traits.copy_unpack_(raw_pointer_cast(dst.data()), src.data().coord_, tuple_seq{}); } }; -template -struct Copy_Traits +template +struct Copy_Traits { using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + using SrcLayout = Layout>; // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + using DstLayout = Layout>; // Reference map from (thr,val) to bit using RefLayout = SrcLayout; @@ -256,7 +220,7 @@ struct Copy_Traits // Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar CUTE_HOST_DEVICE constexpr - Copy_Traits + Copy_Traits with(uint64_t& tma_load_mbar, uint16_t const& multicast_mask) const { return {tma_desc_, tma_load_mbar, multicast_mask}; } @@ -267,10 +231,7 @@ struct Copy_Traits auto get_tma_tensor(GShape const& g_shape) const { static_assert(is_congruent::value); - constexpr int tma_rank = decltype(cute::min(rank(flatten(g_stride_)), Int<5>{}))::value; - return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat(Int<0>{}))), - g_shape, - g_stride_); + return make_counting_tensor(make_layout(g_shape, g_stride_)); } // Don't try to execute a copy with SM90_TMA_LOAD_MULTICAST before calling .with() @@ -287,15 +248,15 @@ struct Copy_Traits ////////////////////////////////////////////////////////////////////////////// // The executable SM90_TMA_STORE with tma_desc -template -struct Copy_Traits +template +struct Copy_Traits { using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + using SrcLayout = Layout>; // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + using DstLayout = Layout>; // Reference map from (thr,val) to bit using RefLayout = SrcLayout; @@ -317,10 +278,7 @@ struct Copy_Traits auto get_tma_tensor(GShape const& g_shape) const { static_assert(is_congruent::value); - constexpr int tma_rank = decltype(cute::min(rank(flatten(g_stride_)), Int<5>{}))::value; - return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat(Int<0>{}))), - g_shape, - g_stride_); + return make_counting_tensor(make_layout(g_shape, g_stride_)); } template @@ -329,20 +287,8 @@ struct Copy_Traits copy_unpack_(void const* const src_ptr, Coord const& dst_coord, seq) const { -#if 0 - print("THR (%d,%d,%d) BLK (%d,%d,%d)\n", - threadIdx.x, threadIdx.y, threadIdx.z, - blockIdx.x, blockIdx.y, blockIdx.z); - print(" TMA Coord "); print(dst_coord); print("\n"); - print(" TMA Shape "); print(make_tuple(uint64_t(tma_desc_.size0_), - uint64_t(tma_desc_.size1_), - uint64_t(tma_desc_.size2_), - uint64_t(tma_desc_.size3_))); print("\n"); -#endif - SM90_TMA_STORE::copy(&tma_desc_, - src_ptr, - get(dst_coord)...); + src_ptr, get(dst_coord)...); } // This is the copy_unpack dispatch for this Copy_Traits @@ -359,7 +305,7 @@ struct Copy_Traits static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE"); //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed src tensor - traits.copy_unpack_(src.data().get(), dst.data().coord_, tuple_seq{}); + traits.copy_unpack_(raw_pointer_cast(src.data()), dst.data().coord_, tuple_seq{}); } }; @@ -367,17 +313,17 @@ struct Copy_Traits ///////////////////////////// BULK COPY ////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////// -template -struct Copy_Traits +template +struct Copy_Traits { - static_assert(int32_t(NumBits::value / 8) % 16 == 0, + static_assert(int32_t(NumBitsPerTMA::value / 8) % 16 == 0, "Bulk Copy requires copy vector size align to 16B."); using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + using SrcLayout = Layout>; // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + using DstLayout = Layout>; // Reference map from (thr,val) to bit using RefLayout = SrcLayout; @@ -398,28 +344,28 @@ struct Copy_Traits static_assert(is_gmem::value, "Expected gmem src for SM90_BULK_COPY_G2S"); static_assert(is_smem::value, "Expected smem dst for SM90_BULK_COPY_G2S"); SM90_BULK_COPY_G2S::copy(src.data().get(), *get<0>(traits.bulk_load_mbar_), - dst.data().get(), int32_t(NumBits::value / 8)); + dst.data().get(), int32_t(NumBitsPerTMA::value / 8)); } // Record the memory barrier for the instruction CUTE_HOST_DEVICE constexpr - Copy_Traits + Copy_Traits with(uint64_t& bulk_mbar) const { return {{&bulk_mbar}}; } }; -template -struct Copy_Traits +template +struct Copy_Traits { - static_assert(int32_t(NumBits::value / 8) % 16 == 0, + static_assert(int32_t(NumBitsPerTMA::value / 8) % 16 == 0, "Bulk Copy requires copy vector size align to 16B."); using ThrID = Layout<_1>; // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; + using SrcLayout = Layout>; // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; + using DstLayout = Layout>; // Reference map from (thr,val) to bit using RefLayout = SrcLayout; @@ -433,7 +379,7 @@ struct Copy_Traits { static_assert(is_smem::value, "Expected smem src for SM90_BULK_COPY_S2G"); static_assert(is_gmem::value, "Expected gmem dst for SM90_BULK_COPY_S2G"); - SM90_BULK_COPY_S2G::copy(src.data().get(), dst.data().get(), int32_t(NumBits::value / 8)); + SM90_BULK_COPY_S2G::copy(src.data().get(), dst.data().get(), int32_t(NumBitsPerTMA::value / 8)); } }; @@ -469,112 +415,98 @@ struct Copy_Traits // MAKE_TMA_COPY and related // -namespace detail -{ - -template -auto -get_swizzle_portion(ComposedLayout,Offset,SLayout>) -{ - return Swizzle{}; -} +namespace detail { -template -auto -get_swizzle_portion(Layout) -{ - return Swizzle<0,4,3>{}; -} - -template -auto -get_nonswizzle_portion(ComposedLayout,Offset,SLayout> const& slayout) -{ - return slayout.layout_fn(); -} - -template -auto -get_nonswizzle_portion(Layout const& slayout) -{ - return slayout; -} - -template -TMA::SmemSwizzleBits -get_tma_swizzle_bits(Swizzle) -{ - if constexpr (M == 4) { - switch (B) { - default: static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. Unsupported layout swizzle."); - case 3: return TMA::SmemSwizzleBits::B128; - case 2: return TMA::SmemSwizzleBits::B64; - case 1: return TMA::SmemSwizzleBits::B32; - case 0: return TMA::SmemSwizzleBits::DISABLE; - } - } else - { - static_assert(M < 0, "Unsupported layout swizzle."); - } -} - -template -TMA::SmemSwizzleBits -get_tma_swizzle_bits(Layout const& layout) -{ - return get_tma_swizzle_bits(get_swizzle_portion(layout)); -} - -#if !defined(__CUDACC_RTC__) // Use a smem2gmode map to read through the GMEM tensor // and construct a TMA Descriptor for the resulting instruction template -CUTE_HOST +CUTE_HOST_RTC auto make_tma_copy_desc(Tensor const& gtensor, // The original GMEM Tensor - Layout const& smem_inv, // smem_idx to flat gmode + Layout const& smem_inv_h, // smem_idx to hier gmode Swizzle const& swizzle) // Swizzle fn on smem_idx { using T = typename GEngine::value_type; - auto flat_glayout = flatten(gtensor.layout()); - CUTE_STATIC_ASSERT_V(rank(flat_glayout) == rank(smem_inv)); - constexpr int rank_smem_inv = decltype(rank(smem_inv))::value; + // This is the gmem "vector" that corresponds to the smem vector in memory (smem_box_shape):(gmem_prob_stride) + Tensor tma_gstride = recast(gtensor.compose(smem_inv_h)); + + // If the sizes of smem_inv_h and tma_gstride don't match, then a non-trivial recast was performed. + // In that case, require that the recasted modes all have size-1 so TMA can identity them and skip them. + for_each(zip(flatten(shape(smem_inv_h)), flatten(shape(tma_gstride))), [] (auto s_and_g) { + auto [s,g] = s_and_g; + CUTE_STATIC_ASSERT_V(s == g or g == Int<1>{}, + "A non-trivial recast was performed, but TMA cannot identify which modes to leave out."); + }); + + // Perform the tiling to the gmem vector again, but with indirections to the gtensor modes + auto gbasis = make_identity_layout(shape(gtensor)); + auto tma_gbasis_tile_tmp = gbasis.compose(smem_inv_h); + // Instead of the recast (gbasis doesn't have type info), replace the shape with the already-recasted shape and coalesce out any size-1 modes + auto tma_gbasis_tile = coalesce(make_layout(shape(tma_gstride), stride(tma_gbasis_tile_tmp))); - auto tma_multimode = rank(flat_glayout) > Int<5>{}; - constexpr uint32_t tma_dim = cute::min(rank(flat_glayout), 5);; + // Recast the original tensor for shape inspections + auto glayout_T = recast(gtensor).layout(); + + // Find missing bases that don't belong to a size-1 mode of the recast input + // NOTE This is essentially ArithmeticTuple complement... + // NOTE in persuit of implementing an ArithmeticTuple logical_divide for smem_inv_h + auto tma_gbasis_full = fold(zip(flatten(shape(glayout_T)), flatten(stride(gbasis))), tma_gbasis_tile, + [](auto tma_g, auto s_and_d) { + auto [s,d] = s_and_d; + auto k = find(stride(tma_g), d); // Find the basis in tma_gstride + if constexpr (decltype(k != rank(tma_g) || is_constant<1, decltype(s)>{})::value) { + // If d was found or s is static-1, then don't append + return tma_g; + } else { + // Else, append the missing basis + return append(tma_g, make_layout(Int<1>{}, d)); + } + }); + + // Group the trailing modes to make this max rank-5 + auto tma_gbasis = group(tma_gbasis_full); + +#if 0 + print("gtensor : "); print(gtensor); print("\n"); + print("smem_inv_h : "); print(smem_inv_h); print("\n"); + print("tma_gstride : "); print(tma_gstride); print("\n"); + print("gbasis : "); print(gbasis); print("\n"); + print("tma_gb_tile : "); print(tma_gbasis_tile ); print("\n"); + print("tma_gbasis : "); print(tma_gbasis); print("\n"); +#endif + + constexpr int tma_dim = decltype(rank(tma_gbasis))::value; // // TMA gmem desc info // - void* gmem_address = (void*) gtensor.data(); + void* gmem_address = (void*) raw_pointer_cast(gtensor.data()); cute::array gmem_prob_shape = {1,1,1,1,1}; cute::array gmem_prob_stride = {0,0,0,0,0}; - for_each(make_seq{}, [&](auto i) { - auto e = stride(smem_inv); // For g++-7.5, let it deduce e rather than fuse with below - constexpr int j = decltype(e.mode())::value; - constexpr int tma_i = i < 5 ? i : 4; - - // Problem stride - uint64_t stride_j = stride(flat_glayout) * sizeof(T); - uint64_t old_stride = gmem_prob_stride[tma_i]; - gmem_prob_stride[tma_i] = gcd(gmem_prob_stride[tma_i], stride_j); - - // Problem shape - uint64_t shape_j = shape(flat_glayout); - if (gmem_prob_stride[tma_i] != 0) { - // We're "resetting" this TMA mode and using it as a "multimode" - // Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1 - gmem_prob_shape[tma_i] = (gmem_prob_shape[tma_i]-1) * (old_stride / gmem_prob_stride[tma_i]) - + (shape_j-1) * (stride_j / gmem_prob_stride[tma_i]) - + 1; - } else { - gmem_prob_shape[tma_i] = shape_j; - } + // Use the indirections in tma_gbasis in the values of flat_glayout to construct the gmem shapes/strides + for_each(make_seq{}, [&](auto i) { + for_each(stride(tma_gbasis), [&](auto ej) { + // Problem stride + uint64_t stride_j = basis_get(ej, stride(glayout_T)) * sizeof(T); + uint64_t old_stride = gmem_prob_stride[i]; + gmem_prob_stride[i] = gcd(gmem_prob_stride[i], stride_j); + + // Problem shape + uint64_t shape_j = basis_get(ej, shape(glayout_T)); + if (gmem_prob_stride[i] != 0) { + // Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1 + gmem_prob_shape[i] = (gmem_prob_shape[i]-1) * (old_stride / gmem_prob_stride[i]) + + (shape_j-1) * (stride_j / gmem_prob_stride[i]) + + 1; + } else { + gmem_prob_shape[i] = shape_j; + } + }); }); assert((reinterpret_cast(gmem_address) & 0b1111) == 0); // Address must be 16B-aligned @@ -590,7 +522,9 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM assert(gmem_prob_shape[4] >= (uint64_t(1))); // Size must be min 1 assert(gmem_prob_shape[4] <= (uint64_t(1) << 32)); // Size must be max 2^32 - assert((gmem_prob_stride[0]) == sizeof(T)); // First stride is implicitly 1 + // TMA descriptor does not store the zeroth stride and assumes it is sizeof(T) == one element. + assert(gmem_prob_stride[0] == sizeof(T) && "Majorness of smem doesn't match majorness of gmem"); + assert((gmem_prob_stride[1]) < (uint64_t(1) << 40)); // Stride must be max 2^40 assert((gmem_prob_stride[1] & 0b1111) == 0); // Stride must be multiple of 16B (128b) assert((gmem_prob_stride[2]) < (uint64_t(1) << 40)); // Stride must be max 2^40 @@ -606,36 +540,30 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM cute::array smem_box_shape = {1,1,1,1,1}; cute::array smem_box_stride = {1,1,1,1,1}; - for_each(make_seq{}, [&](auto i) { - uint32_t shape_i = shape(smem_inv); - constexpr int tma_i = i < 5 ? i : 4; - if (tma_multimode && tma_i == 4) { - // We're "reusing" this TMA mode and using it as a "multimode" - smem_box_shape[tma_i] = 1; - } else { - smem_box_shape[tma_i] = shape_i; - } + // The smem box is simply given by the sizes of the modes in tma_gbasis + for_each(make_seq{}, [&](auto i) { + smem_box_shape[i] *= size(tma_gbasis); }); assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 - assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 + assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 = 256 assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 - assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 + assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 = 256 assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 - assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 + assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 = 256 assert(smem_box_shape[0] >= (uint64_t(1))); // Size must be min 1 - assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 + assert(smem_box_shape[0] <= (uint64_t(1) << 8)); // Size must be max 2^8 = 256 assert(smem_box_stride[0] >= (uint32_t(1))); // Stride must be min 1 - assert(smem_box_stride[0] <= (uint32_t(8))); // Stride must be max 2^3 + assert(smem_box_stride[0] <= (uint32_t(8))); // Stride must be max 2^3 = 8 assert(smem_box_stride[1] >= (uint32_t(1))); // Stride must be min 1 - assert(smem_box_stride[1] <= (uint32_t(8))); // Stride must be max 2^3 + assert(smem_box_stride[1] <= (uint32_t(8))); // Stride must be max 2^3 = 8 assert(smem_box_stride[2] >= (uint32_t(1))); // Stride must be min 1 - assert(smem_box_stride[2] <= (uint32_t(8))); // Stride must be max 2^3 + assert(smem_box_stride[2] <= (uint32_t(8))); // Stride must be max 2^3 = 8 assert(smem_box_stride[3] >= (uint32_t(1))); // Stride must be min 1 - assert(smem_box_stride[3] <= (uint32_t(8))); // Stride must be max 2^3 + assert(smem_box_stride[3] <= (uint32_t(8))); // Stride must be max 2^3 = 8 assert(smem_box_stride[4] >= (uint32_t(1))); // Stride must be min 1 - assert(smem_box_stride[4] <= (uint32_t(8))); // Stride must be max 2^3 + assert(smem_box_stride[4] <= (uint32_t(8))); // Stride must be max 2^3 = 8 // // Construct the descriptor @@ -647,11 +575,11 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM // TMA general info // -#if (__CUDACC_VER_MAJOR__ >= 12) +#if (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType(); CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; - CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE; + CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; // TMA smem swizzle type @@ -687,31 +615,43 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM assert(false); } -#endif // (__CUDACC_VER_MAJOR__ >= 12) +#endif // (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) // Finally, get the inverse permutation of the E bases for the mocked gmem stride - auto gmem_stride_bases_flat = transform(make_seq{}, [&](auto i) { - auto k = find(stride(smem_inv), E{}); - // For gcc 7.5 -- avoid 'if constexpr' - int32_t tma_coord_stride = int32_t(stride(flat_glayout) * sizeof(T) / (gmem_prob_stride[4] != 0 ? gmem_prob_stride[4] : 16)); - return conditional_return(tma_multimode && (k >= Int<4>{}), - E<4>{} * tma_coord_stride, // The 4th TMA mode is the multimode, use int32_t coord stride - E{}); + // NOTE This is essentially ArithmeticTuple inverse... + auto gmem_stride_bases = transform_leaf(stride(gbasis), [&](auto ei) { + auto si = basis_get(ei, shape(glayout_T)); + auto di = basis_get(ei, stride(glayout_T)); + auto tma_gbasis_stride = stride(tma_gbasis); + // Find j such that E is in stride(tma_gbasis) + [[maybe_unused]] auto j = find_if(tma_gbasis_stride, [&](auto tma_stride_j) { return any_of(tma_stride_j, [&](auto dj) { return dj == ei; }); }); + // Return the TMA basis this gmode contributes to + if constexpr (is_constant<1, decltype(si)>::value || decltype(j == rank(tma_gbasis_stride))::value) { + return Int<0>{}; // Return arithmetic identity -- no contribution to the TMA + } else + if constexpr (decltype(rank(tma_gbasis_stride) == Int<1>{})::value) { + return E{}; // We know that the scale factor is Int<1>{} + } else { + return E{} * int32_t(di * sizeof(T) / cute::max(gmem_prob_stride[j], 16)); + } }); - // Give that the profile of gtensor and fold it - // NOTE: This is the only reason we want the original gtensor shape rather than the more intuitive flattened shape - auto gmem_stride_bases = stride(composition(make_layout(repeat_like(shape(flat_glayout), Int<2>{}), gmem_stride_bases_flat), - make_layout(repeat_like(shape(gtensor), Int<2>{})))); +#if 0 + print("gmem_stride_bases : "); print(gmem_stride_bases); print("\n"); +#endif - return make_tuple(tma_desc, gmem_stride_bases); + return cute::make_tuple(tma_desc, gmem_stride_bases); } +// The "logical TMA tid" is a map from the CTA rank to its logical id +// within the instruction. It works like a mask or ordering on the +// CTAs. For non-multicast TMA, all CTAs should map to 0. For +// multicast TMA of size 4, CTAs will be mapped to {0,1,2,3}. template -CUTE_HOST +CUTE_HOST_RTC auto make_tma_copy_tiled(CopyOp, Tensor const& gtensor, // Full GMEM Tensor @@ -732,8 +672,6 @@ make_tma_copy_tiled(CopyOp, // TMA slayout manipulation // - auto flat_glayout = flatten(gtensor.layout()); - // Invert the smem to get the largest contiguous vector in the smem layout auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); // trunc_smem_idx -> trunc_smem_coord @@ -741,32 +679,13 @@ make_tma_copy_tiled(CopyOp, // Map from smem idx to a gmem mode auto sidx_to_gmode = coalesce(composition(cta_v_map, inv_smem_layout)); - // Truncate any incompatibilities - auto smem_rank = find_if(stride(sidx_to_gmode), [](auto e) { - auto v = basis_value(e); - return not is_constant<1,decltype(v)>{}; - }); - static_assert(smem_rank > 0, "Could not find a common smem-gmem vectorization for TMA. Do they have a common majorness?"); - // TMA uses a maximum of 5 modes - // If the gtensor has more than 5 modes, we need to reserve the last TMA-mode as a "multimode" - constexpr int smem_tma_rank = cute::min(int(smem_rank), (rank(flat_glayout) > Int<5>{} ? 4 : 5)); - - // Keep only the static-1 basis modes into gmem - auto sidx_to_gmode_trunc = take<0,smem_tma_rank>(sidx_to_gmode); - - // Split according to the portion each multicast CTA will be responsible for - auto sidx_to_gmode_vt = logical_divide(sidx_to_gmode_trunc, shape_div(size(sidx_to_gmode_trunc), cosize(cta_t_map))); - #if 0 - print("g_layout : "); print(gtensor.layout()); print("\n"); - print("s_layout : "); print(slayout); print("\n"); - print("cta_t_map : "); print(cta_t_map); print("\n"); - print("cta_v_map : "); print(cta_v_map); print("\n"); - print("inv_smem : "); print(inv_smem_layout); print("\n"); - print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n"); - - print("sidx_to_gmode_trunc : "); print(sidx_to_gmode_trunc); print("\n"); - print("sidx_to_gmode_vt : "); print(sidx_to_gmode_vt); print("\n"); + print("g_layout : "); print(gtensor.layout()); print("\n"); + print("s_layout : "); print(slayout); print("\n"); + print("cta_t_map : "); print(cta_t_map); print("\n"); + print("cta_v_map : "); print(cta_v_map); print("\n"); + print("inv_smem : "); print(inv_smem_layout); print("\n"); + print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n"); #endif // @@ -774,49 +693,51 @@ make_tma_copy_tiled(CopyOp, // // Generate a TupleBasis for the gtensor - auto flat_gbasis = make_basis_like(shape(flat_glayout)); - - // Fold the flat_gbasis into the glayout - auto glayout_basis = make_layout(shape(gtensor), - stride(composition(make_layout(repeat_like(shape(flat_glayout), Int<2>{}), flat_gbasis), - make_layout(repeat_like(shape(gtensor), Int<2>{}))))); + auto glayout_basis = make_identity_layout(shape(gtensor)); // Tile the modes of gtensor with the truncated cta_v_map o inv_smem_layout_trunc - auto tma_layout_v_trunc = flatten(composition(glayout_basis, layout<0>(sidx_to_gmode_vt))); + auto tma_layout_full = flatten(composition(glayout_basis, sidx_to_gmode)); - // Append any missing basis on the end as size-1 modes b/c they got truncated - // NOTE This is essentially ArithmeticTuple complement... - auto missing_basis = fold(stride(tma_layout_v_trunc), flat_gbasis, [](auto init, auto e) { - auto k = find(init, e); - return remove(init); + // Truncate any incompatibilities -- no starting in the middle of gmodes + auto smem_rank = find_if(stride(tma_layout_full), [](auto e) { + [[maybe_unused]] auto v = basis_value(e); + return not is_constant<1,decltype(v)>{}; }); + static_assert(smem_rank > 0, "Could not find a common tile-gmem vectorization. Does the Tile select out major GMEM modes?"); + // TMA uses a maximum of 5 modes + // If the gtensor has more than 5 modes, we need to reserve the last TMA-mode as a "multimode" + constexpr int smem_tma_rank = cute::min(int(smem_rank), (rank(tma_layout_full) > Int<5>{} ? 4 : 5)); + + // Keep only the static-1 basis modes into gmem + auto tma_layout_trunc = take<0,smem_tma_rank>(tma_layout_full); - // The appended map from truncated smem codomain to gmem mode: trunc_smem_idx -> gmem_mode - auto tma_layout_v = make_layout(flatten(cute::make_tuple(tma_layout_v_trunc.shape(), repeat(Int<1>{}))), - flatten(cute::make_tuple(tma_layout_v_trunc.stride(), missing_basis))); + // Split according to the portion each multicast CTA will be responsible for + auto tma_layout_vt = logical_divide(tma_layout_trunc, shape_div(size(tma_layout_trunc), cosize(cta_t_map))); #if 0 - print("flat_gbasis : "); print(flat_gbasis); print("\n"); - print("missing_b : "); print(missing_basis); print("\n"); - print("tma_layout_v : "); print(tma_layout_v); print("\n"); + print("glayout_basis : "); print(glayout_basis); print("\n"); + print("tma_layout_full : "); print(tma_layout_full); print("\n"); + + print("tma_layout_trunc: "); print(tma_layout_trunc); print("\n"); + print("tma_layout_vt : "); print(tma_layout_vt); print("\n"); #endif // // Construct the TMA Desc and GMEM mode ordering // - auto [tma_desc, gmem_stride_bases] = detail::make_tma_copy_desc(gtensor, tma_layout_v, get_swizzle_portion(slayout)); + auto [tma_desc, gmem_stride_bases] = detail::make_tma_copy_desc(gtensor, layout<0>(tma_layout_vt), get_swizzle_portion(slayout)); // // Construct the Copy_Traits // using T = typename GEngine::value_type; - constexpr int num_bits = decltype(size<0>(sidx_to_gmode_vt))::value * sizeof(T) * 8; - using Traits = Copy_Traits, decltype(gmem_stride_bases)>; + constexpr int num_bits_per_tma = decltype(size<0>(tma_layout_vt))::value * sizeof(T) * 8; + using Traits = Copy_Traits, decltype(gmem_stride_bases)>; #if 0 - print("num_bits : "); print(num_bits); print("\n"); + print("num_bits : "); print(NumBitsPerTMA{}); print("\n"); print("g_stride_bases: "); print(gmem_stride_bases); print("\n"); #endif @@ -829,8 +750,12 @@ make_tma_copy_tiled(CopyOp, auto cta_tiler = product_each(shape(cta_v_map)); // (CTA V, CTA T) -> smem_coord - auto layout_vt = composition(inv_smem_layout, make_layout(shape(sidx_to_gmode_vt))); + auto layout_vt = composition(inv_smem_layout, make_layout(shape(tma_layout_vt))); // Scale that up to cover all of the smem_coords + // + // The smem vector might not cover all of the tile, + // so multiply it up to cover the entire tile. + // "T" here (the parallel index) is a CTA index. auto layout_VT = tile_to_shape(layout_vt, make_shape(size(cta_v_map)/size<1>(layout_vt), size<1>(layout_vt))); // Flip it and change the domain of the T from logical thr to thr_idx auto layout_TV = make_layout(composition(layout<1>(layout_VT), cta_t_map), layout<0>(layout_VT)); @@ -844,7 +769,6 @@ make_tma_copy_tiled(CopyOp, using T = typename GEngine::value_type; return TiledCopy, decltype(layout_TV), decltype(cta_tiler)>{tma_traits}; } -#endif // !defined(__CUDACC_RTC__) } // end namespace detail @@ -920,13 +844,12 @@ make_tma_copy_tiled(CopyOp, copy(tma.with(barrier, mcast_mask), tAgA, tAsA); // copy with supporting TMA params */ -#if !defined(__CUDACC_RTC__) template -CUTE_HOST +CUTE_HOST_RTC auto make_tma_copy(CopyOp const& copy_op, Tensor const& gtensor, @@ -946,7 +869,7 @@ make_tma_copy(CopyOp const& copy_op, template -CUTE_HOST +CUTE_HOST_RTC auto make_tma_copy(CopyOp const& copy_op, Tensor const& gtensor, @@ -959,7 +882,7 @@ template -CUTE_HOST +CUTE_HOST_RTC auto make_tma_copy(CopyOp const& copy_op, Tensor const& gtensor, @@ -968,6 +891,5 @@ make_tma_copy(CopyOp const& copy_op, { return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), cluster_size); } -#endif // !defined(__CUDACC_RTC__) } // end namespace cute diff --git a/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp b/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp new file mode 100644 index 0000000000..6d391b2173 --- /dev/null +++ b/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp @@ -0,0 +1,70 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/// @file copy_traits_sm90_tma_swizzle.hpp +/// @brief Functions for converting swizzle layout to TMA descriptor + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include "cute/arch/copy_sm90_desc.hpp" +#include "cute/swizzle_layout.hpp" + +namespace cute::detail { + +template +TMA::SmemSwizzleBits +get_tma_swizzle_bits(Swizzle) +{ + if constexpr (M == 4) { + switch (B) { + default: static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. Unsupported layout swizzle."); + case 3: return TMA::SmemSwizzleBits::B128; + case 2: return TMA::SmemSwizzleBits::B64; + case 1: return TMA::SmemSwizzleBits::B32; + case 0: return TMA::SmemSwizzleBits::DISABLE; + } + } else + { + static_assert(M < 0, "Unsupported layout swizzle."); + } +} + +template +TMA::SmemSwizzleBits +get_tma_swizzle_bits(Layout const& layout) +{ + return get_tma_swizzle_bits(get_swizzle_portion(layout)); +} + +} // namespace cute::detail diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index a9ca3660a4..844d653eeb 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -978,9 +978,9 @@ print_latex_mma(Shape_MNK const& shape_mnk, printf(latex_header); - constexpr int M = size<0>(shape_mnk); - constexpr int N = size<1>(shape_mnk); - constexpr int K = size<2>(shape_mnk); + auto M = size<0>(shape_mnk); + auto N = size<1>(shape_mnk); + auto K = size<2>(shape_mnk); // C starting at 0,0 bool c_filled[M][N] = {}; diff --git a/include/cute/atom/mma_traits.hpp b/include/cute/atom/mma_traits.hpp index 7242e2d45a..56145934f7 100644 --- a/include/cute/atom/mma_traits.hpp +++ b/include/cute/atom/mma_traits.hpp @@ -127,7 +127,7 @@ mma_unpack(MMA_Traits const& traits, using RegTypeC = typename remove_extent::type; using MMATraits = MMA_Traits; - constexpr int RegNumD = extent::value; + [[maybe_unused]] constexpr int RegNumD = extent::value; constexpr int RegNumA = extent::value; constexpr int RegNumB = extent::value; constexpr int RegNumC = extent::value; @@ -186,6 +186,26 @@ mma_unpack(MMA_Traits const& traits, } } +// +// Accept mutable temporaries +// + +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(MMA_Traits const& traits, + Tensor && D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + mma_unpack(traits, D, A, B, C); +} + namespace detail { template diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp index 752023c2fa..993205c413 100644 --- a/include/cute/atom/mma_traits_sm90_gmma.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -67,13 +67,13 @@ namespace GMMA { /////////////////////////////////////////// // M|N-major GMMA layouts in units of bits -using Layout_MN_INTER_Atom_Bits = Layout,Stride<_1,_128>>; +using Layout_MN_INTER_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _128>>>; using Layout_MN_SW32_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _256>>>; using Layout_MN_SW64_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _512>>>; using Layout_MN_SW128_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1,_1024>>>; // K-major GMMA layouts in units of bits -using Layout_K_INTER_Atom_Bits = Layout,Stride<_128,_1>>; +using Layout_K_INTER_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _128,_1>>>; using Layout_K_SW32_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _256,_1>>>; using Layout_K_SW64_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _512,_1>>>; using Layout_K_SW128_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1024,_1>>>; @@ -101,20 +101,20 @@ using Layout_K_SW128_Atom = decltype(upcast::value>(Layout_K_S // With GMMA::Major param template using Layout_INTER_Atom = typename conditional, - Layout_K_INTER_Atom>::type; + Layout_MN_INTER_Atom, + Layout_K_INTER_Atom>::type; template using Layout_SW32_Atom = typename conditional, - Layout_K_SW32_Atom>::type; + Layout_MN_SW32_Atom, + Layout_K_SW32_Atom>::type; template using Layout_SW64_Atom = typename conditional, - Layout_K_SW64_Atom>::type; + Layout_MN_SW64_Atom, + Layout_K_SW64_Atom>::type; template using Layout_SW128_Atom = typename conditional, - Layout_K_SW128_Atom>::type; + Layout_MN_SW128_Atom, + Layout_K_SW128_Atom>::type; // // Tensor to LayoutType utility @@ -208,14 +208,14 @@ make_gmma_desc(Tensor const& tensor) // Layout type constexpr GMMA::LayoutType LAYOUT_TYPE = GMMA::layout_type(u128_tensor); - desc.layout_type_ = uint8_t(LAYOUT_TYPE); + desc.bitfield.layout_type_ = uint8_t(LAYOUT_TYPE); // Start address (4LSB not included) uint32_t start_address = cast_smem_ptr_to_uint(u128_tensor.data().get()); - desc.start_address_ = start_address >> 4; + desc.bitfield.start_address_ = start_address >> 4; constexpr uint8_t base_offset = 0; - desc.base_offset_ = base_offset; + desc.bitfield.base_offset_ = base_offset; // LayoutType meta constexpr int W = LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE ? 1 : @@ -253,8 +253,8 @@ make_gmma_desc(Tensor const& tensor) constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); constexpr uint32_t stride_11 = stride<1,1>(canonical_layout); - desc.stride_byte_offset_ = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride_01 : stride_11; - desc.leading_byte_offset_ = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride_11 : stride_01; + desc.bitfield.stride_byte_offset_ = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride_01 : stride_11; + desc.bitfield.leading_byte_offset_ = (LAYOUT_TYPE == GMMA::LayoutType::INTERLEAVE) ? stride_11 : stride_01; } else if constexpr (MajorMode == GMMA::Major::K) { @@ -287,8 +287,8 @@ make_gmma_desc(Tensor const& tensor) // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); - desc.stride_byte_offset_ = stride_01; - desc.leading_byte_offset_ = stride_10; + desc.bitfield.stride_byte_offset_ = stride_01; + desc.bitfield.leading_byte_offset_ = stride_10; } else { static_assert(MajorMode != GMMA::Major::MN && MajorMode != GMMA::Major::K, "Unrecognized MajorMode!"); } @@ -330,19 +330,6 @@ struct DescriptorIterator CUTE_HOST_DEVICE constexpr DescriptorIterator operator+(Index const& offset) const { - // offset is in the units of uint128_t (4LSB of start_address not included) - - //GmmaDescriptor desc = desc_; - //desc.start_address_ += uint16_t(offset); - //desc.reg32_[0] += uint16_t(offset); // Generates better asm than adding to the bitfield - - // May need to update base_offset if swizzle alignment isn't guaranteed - //desc.base_offset_ = 0; - //assert((desc.start_address_ & 0b111000) == 0); // Assert base_offset is 0, generalize later - - //return {desc}; - - // The above seems to not work for some reason... return { GmmaDescriptor {desc_ + uint64_t(offset)} }; } @@ -3182,4 +3169,2755 @@ struct MMA_Traits //////////////////////////////////////////////////////////////////////////////////////////////////// +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e4m3_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e4m3_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = half_t; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = half_t; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementAFrg = GMMA::smem_desc; + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ElementDVal = float; + using ElementAVal = float_e5m2_t; + using ElementBVal = float_e5m2_t; + using ElementCVal = float; + + using ElementBFrg = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // end namespace cute diff --git a/include/cute/config.hpp b/include/cute/config.hpp index 4b77ee1e37..4a12f1c584 100644 --- a/include/cute/config.hpp +++ b/include/cute/config.hpp @@ -40,6 +40,12 @@ # define CUTE_HOST inline #endif // CUTE_HOST_DEVICE, CUTE_DEVICE +#if defined(__CUDACC_RTC__) +# define CUTE_HOST_RTC CUTE_HOST_DEVICE +#else +# define CUTE_HOST_RTC CUTE_HOST +#endif + #if !defined(__CUDACC_RTC__) && (defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)) # define CUTE_UNROLL #pragma unroll # define CUTE_NO_UNROLL #pragma unroll 1 @@ -84,9 +90,7 @@ // It's harmless to use the macro for other GCC versions or other // compilers, but it has no effect. #if ! defined(CUTE_GCC_UNREACHABLE) -# if defined(__GNUC__) && __GNUC__ < 11 - // GCC 10, but not 7.5, 9.4.0, or 11, issues "missing return - // statement" warnings without this little bit of help. +# if defined(__clang__) || defined(__GNUC__) # define CUTE_GCC_UNREACHABLE __builtin_unreachable() # else # define CUTE_GCC_UNREACHABLE @@ -151,7 +155,6 @@ #include #include #include - // // Debugging utilities // diff --git a/include/cute/container/array.hpp b/include/cute/container/array.hpp index 9e70e87f93..3b0831657a 100644 --- a/include/cute/container/array.hpp +++ b/include/cute/container/array.hpp @@ -194,6 +194,146 @@ struct array }; +template +struct array +{ + using value_type = T; + using size_type = size_t; + using difference_type = ptrdiff_t; + using reference = value_type&; + using const_reference = const value_type&; + using pointer = value_type*; + using const_pointer = const value_type*; + using const_iterator = const_pointer; + using iterator = const_iterator; + + CUTE_HOST_DEVICE constexpr + reference operator[](size_type pos) + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + const_reference operator[](size_type pos) const + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + reference front() + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + const_reference front() const + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + reference back() + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + const_reference back() const + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + T* data() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + T const* data() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + iterator begin() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator begin() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + iterator end() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator end() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + bool empty() const + { + return true; + } + + CUTE_HOST_DEVICE constexpr + size_type size() const + { + return 0; + } + + CUTE_HOST_DEVICE constexpr + size_type max_size() const + { + return 0; + } + + CUTE_HOST_DEVICE constexpr + void fill(const T& value) + {} + + CUTE_HOST_DEVICE constexpr + void clear() + {} + + CUTE_HOST_DEVICE constexpr + void swap(array& other) + {} +}; + template CUTE_HOST_DEVICE constexpr bool operator==(array const& lhs, array const& rhs) @@ -227,6 +367,22 @@ void swap(array& a, array& b) a.swap(b); } +/// @return A cute::array of the elements of @c t in reverse order. +template +CUTE_HOST_DEVICE constexpr cute::array +reverse(cute::array const& t) { + if constexpr (N == 0u) { + return t; + } + else { + cute::array t_r{}; + for (size_t k = 0; k < N; ++k) { + t_r[k] = t[N - k - 1]; + } + return t_r; + } +} + } // end cute @@ -274,7 +430,7 @@ namespace CUTE_STL_NAMESPACE template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template @@ -285,7 +441,7 @@ struct tuple_element> template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template @@ -310,7 +466,7 @@ struct tuple_element; template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template @@ -321,7 +477,7 @@ struct tuple_element> template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template @@ -330,5 +486,5 @@ struct tuple_element> using type = T; }; -} // end namepsace std +} // end namespace std #endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/container/array_subbyte.hpp b/include/cute/container/array_subbyte.hpp index e3fd8ee495..f39bcd66bf 100644 --- a/include/cute/container/array_subbyte.hpp +++ b/include/cute/container/array_subbyte.hpp @@ -37,325 +37,265 @@ #include -#include // sizeof_bits +#include // sizeof_bits #include +#include // dummy_type namespace cute { -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Statically sized array for any data type -template -class array_subbyte -{ - public: - - /// Number of total bits in the array - static constexpr int kSizeBits = sizeof_bits::value * N; - - /// Storage type - using Storage = conditional_t<(kSizeBits % 32) == 0, uint32_t, - conditional_t<(kSizeBits % 16) == 0, uint16_t, - uint8_t>>; - - /// Number of logical elements per stored object - static constexpr int kElementsPerStoredItem = sizeof_bits::value / sizeof_bits::value; - - /// Number of storage elements - static constexpr size_t kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem; - - /// Bitmask for covering one item - static constexpr Storage bit_mask_ = ((Storage(1) << sizeof_bits::value) - 1); - - // - // C++ standard members with reference and iterator types omitted - // - - using value_type = T; - using pointer = value_type*; - using const_pointer = value_type const*; - - using size_type = size_t; - using difference_type = ptrdiff_t; - - // - // References - // - - /// Reference object inserts or extracts sub-byte items - class reference { - /// Pointer to storage element - Storage* ptr_; +// +// Underlying subbyte storage type +// +template +using subbyte_storage_type_t = conditional_t<(sizeof_bits_v <= 8), uint8_t, + conditional_t<(sizeof_bits_v <= 16), uint16_t, + conditional_t<(sizeof_bits_v <= 32), uint32_t, + conditional_t<(sizeof_bits_v <= 64), uint64_t, + conditional_t<(sizeof_bits_v <= 128), uint128_t, + dummy_type>>>>>; - /// Index into elements packed into Storage object - int idx_; +template +struct subbyte_iterator; - public: +// +// subbyte_reference +// Proxy object for sub-byte element references +// +template +struct subbyte_reference +{ + // Iterator Element type (const or non-const) + using element_type = T; + // Iterator Value type without type qulifier. + using value_type = remove_cv_t; + // Storage type (const or non-const) + using storage_type = conditional_t<(is_const_v), subbyte_storage_type_t const, subbyte_storage_type_t>; - /// Default ctor - CUTE_HOST_DEVICE constexpr - reference() : ptr_(nullptr), idx_(0) {} + static_assert(!is_same_v, "Storage type is not supported"); - /// Ctor - CUTE_HOST_DEVICE constexpr - reference(Storage* ptr, int idx = 0) : ptr_(ptr), idx_(idx) {} + static_assert(sizeof_bits_v <= sizeof_bits_v, + "Size of Element must not be greater than Storage."); - /// Assignment - CUTE_HOST_DEVICE constexpr - reference& operator=(T x) { - Storage item = (x & bit_mask_); - Storage kUpdateMask = Storage(~(bit_mask_ << (idx_ * sizeof_bits::value))); - *ptr_ = Storage((*ptr_ & kUpdateMask) | (item << (idx_ * sizeof_bits::value))); - return *this; - } + // Number of logical elements per stored object + static constexpr uint8_t ElementsPerStoredItem = sizeof_bits_v / sizeof_bits_v; + // Bitmask for covering one item + static constexpr storage_type BitMask = storage_type((storage_type(1) << sizeof_bits_v) - 1); - CUTE_HOST_DEVICE constexpr - T get() const { - if constexpr (is_same::value) { - // Extract to bool -- potentially faster impl - return bool((*ptr_) & (bit_mask_ << (idx_ * sizeof_bits::value))); - } else { - // Extract to T - Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits::value)) & bit_mask_); - return reinterpret_cast(item); - } - } +private: - /// Extract to type T - CUTE_HOST_DEVICE constexpr - operator T() const { - return get(); - } - }; + friend class subbyte_iterator; + + // Pointer to storage element + storage_type* ptr_ = nullptr; - /// Reference object extracts sub-byte items - class const_reference { + // Index into elements packed into storage_type element. RI: 0 <= idx_ < ElementsPerStoredItem + uint8_t idx_ = 0; - /// Pointer to storage element - Storage const* ptr_; + // Ctor + template + CUTE_HOST_DEVICE constexpr + subbyte_reference(PointerType* ptr, uint8_t idx = 0) : ptr_(reinterpret_cast(ptr)), idx_(idx) {} - /// Index into elements packed into Storage object - int idx_; +public: - public: + // Copy Ctor + CUTE_HOST_DEVICE constexpr + subbyte_reference(subbyte_reference const& other) { + *this = element_type(other); + } - /// Default ctor - CUTE_HOST_DEVICE constexpr - const_reference(): ptr_(nullptr), idx_(0) { } + // Copy Assignment + CUTE_HOST_DEVICE constexpr + subbyte_reference& operator=(subbyte_reference const& other) { + return *this = element_type(other); + } - /// Ctor - CUTE_HOST_DEVICE constexpr - const_reference(Storage const* ptr, int idx = 0): ptr_(ptr), idx_(idx) { } + // Dtor + ~subbyte_reference() = default; - CUTE_HOST_DEVICE constexpr - const T get() const { - if constexpr (is_same::value) { - // Extract to bool -- potentially faster impl - return bool((*ptr_) & (bit_mask_ << (idx_ * sizeof_bits::value))); - } else { - // Extract to T - Storage item = Storage((*ptr_ >> (idx_ * sizeof_bits::value)) & bit_mask_); - return reinterpret_cast(item); - } - } + // Assignment + template + CUTE_HOST_DEVICE constexpr + enable_if_t, subbyte_reference&> operator=(element_type x) { + static_assert(is_same_v, "Do not specify template arguments!"); + storage_type item = (reinterpret_cast(x) & BitMask); + storage_type kUpdateMask = storage_type(~(BitMask << (idx_ * sizeof_bits_v))); + *ptr_ = storage_type((*ptr_ & kUpdateMask) | (item << (idx_ * sizeof_bits_v))); + return *this; + } - /// Extract to type T - CUTE_HOST_DEVICE constexpr - operator T() const { - return get(); + CUTE_HOST_DEVICE + element_type get() const { + if constexpr (is_same_v) { // Extract to bool -- potentially faster impl + return bool((*ptr_) & (BitMask << (idx_ * sizeof_bits_v))); + } else { // Extract to element_type + storage_type item = storage_type((*ptr_ >> (idx_ * sizeof_bits_v)) & BitMask); + return reinterpret_cast(item); } - }; - - // - // Iterators - // - - /// Bidirectional iterator over elements - class iterator { - - /// Pointer to storage element - Storage* ptr_; + } - /// Index into elements packed into Storage object - int idx_; + // Extract to type element_type + CUTE_HOST_DEVICE constexpr + operator element_type() const { + return get(); + } +}; - public: - CUTE_HOST_DEVICE constexpr - iterator(): ptr_(nullptr), idx_(0) { } +// +// subbyte_iterator +// Random-access iterator over subbyte references +// +template +struct subbyte_iterator +{ + // Iterator Element type (const or non-const) + using element_type = T; + // Iterator Value type without type qulifier. + using value_type = remove_cv_t; + // Storage type (const or non-const) + using storage_type = conditional_t<(is_const_v), subbyte_storage_type_t const, subbyte_storage_type_t>; + // Reference proxy type + using reference = subbyte_reference; - CUTE_HOST_DEVICE constexpr - iterator(Storage* ptr, int idx = 0): ptr_(ptr), idx_(idx) { } + static_assert(!is_same_v, "Storage type is not supported"); - CUTE_HOST_DEVICE constexpr - iterator& operator++() { - ++idx_; - if (idx_ == kElementsPerStoredItem) { - ++ptr_; - idx_ = 0; - } - return *this; - } + static_assert(sizeof_bits_v <= sizeof_bits_v, + "Size of Element must not be greater than Storage."); - CUTE_HOST_DEVICE constexpr - iterator& operator--() { - if (idx_) { - --idx_; - } else { - --ptr_; - idx_ = kElementsPerStoredItem - 1; - } - return *this; - } + // Number of logical elements per stored object + static constexpr uint8_t ElementsPerStoredItem = sizeof_bits_v / sizeof_bits_v; - CUTE_HOST_DEVICE constexpr - iterator operator++(int) { - iterator ret(*this); - ++(*this); - return ret; - } +private: - CUTE_HOST_DEVICE constexpr - iterator operator--(int) { - iterator ret(*this); - --(*this); - return ret; - } + // Pointer to storage element + storage_type* ptr_ = nullptr; - CUTE_HOST_DEVICE constexpr - iterator& operator+=(int k) { - idx_ += k; - ptr_ += idx_ / kElementsPerStoredItem; - idx_ = idx_ % kElementsPerStoredItem; - return *this; - } + // Index into elements packed into storage_type element. RI: 0 <= idx_ < ElementsPerStoredItem + uint8_t idx_ = 0; - CUTE_HOST_DEVICE constexpr - iterator operator+(int k) const { - return iterator(ptr_,idx_) += k; - } +public: - CUTE_HOST_DEVICE constexpr - reference operator*() const { - return reference(ptr_, idx_); - } + template + CUTE_HOST_DEVICE constexpr + subbyte_iterator(PointerType* ptr, uint8_t idx = 0): ptr_(reinterpret_cast(ptr)), idx_(idx) { } - CUTE_HOST_DEVICE constexpr - reference operator[](int k) const { - return *(*this + k); + subbyte_iterator() = default; + CUTE_HOST_DEVICE constexpr + subbyte_iterator& operator++() { + ++idx_; + if (idx_ == ElementsPerStoredItem) { + ++ptr_; + idx_ = 0; } + return *this; + } - CUTE_HOST_DEVICE constexpr - bool operator==(iterator const& other) const { - return ptr_ == other.ptr_ && idx_ == other.idx_; + CUTE_HOST_DEVICE constexpr + subbyte_iterator& operator--() { + if (idx_) { + --idx_; + } else { + --ptr_; + idx_ = ElementsPerStoredItem - 1; } + return *this; + } - CUTE_HOST_DEVICE constexpr - bool operator!=(iterator const& other) const { - return !(*this == other); - } - }; + CUTE_HOST_DEVICE constexpr + subbyte_iterator operator++(int) { + subbyte_iterator ret(*this); + ++(*this); + return ret; + } - /// Bidirectional constant iterator over elements - class const_iterator { + CUTE_HOST_DEVICE constexpr + subbyte_iterator operator--(int) { + subbyte_iterator ret(*this); + --(*this); + return ret; + } - /// Pointer to storage element - Storage const* ptr_; + CUTE_HOST_DEVICE constexpr + subbyte_iterator& operator+=(uint64_t k) { + k += idx_; + ptr_ += k / ElementsPerStoredItem; + idx_ = k % ElementsPerStoredItem; + return *this; + } - /// Index into elements packed into Storage object - int idx_; + CUTE_HOST_DEVICE constexpr + subbyte_iterator operator+(uint64_t k) const { + return subbyte_iterator(ptr_,idx_) += k; + } - public: + CUTE_HOST_DEVICE constexpr + reference operator*() const { + return reference(ptr_, idx_); + } - CUTE_HOST_DEVICE constexpr - const_iterator(): ptr_(nullptr), idx_(0) { } + CUTE_HOST_DEVICE constexpr + reference operator[](uint64_t k) const { + return *(*this + k); + } - CUTE_HOST_DEVICE constexpr - const_iterator(Storage const* ptr, int idx = 0): ptr_(ptr), idx_(idx) { } + CUTE_HOST_DEVICE constexpr + friend bool operator==(subbyte_iterator const& x, subbyte_iterator const& y) { + return x.ptr_ == y.ptr_ && x.idx_ == y.idx_; + } - CUTE_HOST_DEVICE constexpr - const_iterator& operator++() { - ++idx_; - if (idx_ == kElementsPerStoredItem) { - ++ptr_; - idx_ = 0; - } - return *this; - } + CUTE_HOST_DEVICE constexpr + friend bool operator!=(subbyte_iterator const& x, subbyte_iterator const& y) { + return !(x == y); + } +}; - CUTE_HOST_DEVICE constexpr - const_iterator& operator--() { - if (idx_) { - --idx_; - } else { - --ptr_; - idx_ = kElementsPerStoredItem - 1; - } - return *this; - } +// +// array_subbyte +// Statically sized array for non-byte-aligned data types +// +template +struct array_subbyte +{ + using element_type = T; + using value_type = remove_cv_t; + using pointer = element_type*; + using const_pointer = element_type const*; - CUTE_HOST_DEVICE constexpr - const_iterator operator++(int) { - iterator ret(*this); - ++idx_; - if (idx_ == kElementsPerStoredItem) { - ++ptr_; - idx_ = 0; - } - return ret; - } + using size_type = size_t; + using difference_type = ptrdiff_t; - CUTE_HOST_DEVICE constexpr - const_iterator operator--(int) { - iterator ret(*this); - if (idx_) { - --idx_; - } else { - --ptr_; - idx_ = kElementsPerStoredItem - 1; - } - return ret; - } + // + // References + // + using reference = subbyte_reference; + using const_reference = subbyte_reference; - CUTE_HOST_DEVICE constexpr - const_iterator& operator+=(int k) { - idx_ += k; - ptr_ += idx_ / kElementsPerStoredItem; - idx_ = idx_ % kElementsPerStoredItem; - return *this; - } + // + // Iterators + // + using iterator = subbyte_iterator; + using const_iterator = subbyte_iterator; - CUTE_HOST_DEVICE constexpr - const_iterator operator+(int k) const { - return const_iterator(ptr_,idx_) += k; - } + // Storage type (const or non-const) + using storage_type = conditional_t<(is_const_v), subbyte_storage_type_t const, subbyte_storage_type_t>; - CUTE_HOST_DEVICE constexpr - const_reference operator*() const { - return const_reference(ptr_, idx_); - } + static_assert(!is_same_v, "Storage type is not supported"); - CUTE_HOST_DEVICE constexpr - const_reference operator[](int k) const { - return *(*this + k); - } + // Number of logical elements per stored object + static constexpr uint8_t ElementsPerStoredItem = sizeof_bits_v / sizeof_bits_v; - CUTE_HOST_DEVICE constexpr - bool operator==(iterator const& other) const { - return ptr_ == other.ptr_ && idx_ == other.idx_; - } + // Bitmask for covering one item + static constexpr storage_type BitMask = ((storage_type(1) << sizeof_bits::value) - 1); - CUTE_HOST_DEVICE constexpr - bool operator!=(iterator const& other) const { - return !(*this == other); - } - }; + // Number of storage elements + static constexpr size_type StorageElements = (N + ElementsPerStoredItem - 1) / ElementsPerStoredItem; private: - /// Internal storage - Storage storage[kStorageElements]; + // Internal storage + storage_type storage[StorageElements]; public: @@ -365,7 +305,7 @@ class array_subbyte CUTE_HOST_DEVICE constexpr array_subbyte(array_subbyte const& x) { CUTE_UNROLL - for (unsigned i = 0; i < kStorageElements; ++i) { + for (size_type i = 0; i < StorageElements; ++i) { storage[i] = x.storage[i]; } } @@ -385,40 +325,40 @@ class array_subbyte return !N; } - /// Efficient clear method + // Efficient clear method CUTE_HOST_DEVICE constexpr void clear() { CUTE_UNROLL - for (unsigned i = 0; i < kStorageElements; ++i) { - storage[i] = Storage(0); + for (size_type i = 0; i < StorageElements; ++i) { + storage[i] = storage_type(0); } } // Efficient fill method CUTE_HOST_DEVICE constexpr void fill(T const& value) { - Storage item = (reinterpret_cast(value) & bit_mask_); + storage_type item = (reinterpret_cast(value) & BitMask); // Reproduce the value over the bits of the storage item CUTE_UNROLL - for (unsigned s = sizeof_bits::value; s < sizeof_bits::value; s *= 2) { + for (size_type s = sizeof_bits_v; s < sizeof_bits_v; s *= 2) { item |= item << s; } CUTE_UNROLL - for (unsigned i = 0; i < kStorageElements; ++i) { + for (size_type i = 0; i < StorageElements; ++i) { storage[i] = item; } } CUTE_HOST_DEVICE constexpr reference at(size_type pos) { - return reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); + return iterator(storage)[pos]; } CUTE_HOST_DEVICE constexpr const_reference at(size_type pos) const { - return const_reference(storage + pos / kElementsPerStoredItem, pos % kElementsPerStoredItem); + return const_iterator(storage)[pos]; } CUTE_HOST_DEVICE constexpr @@ -443,12 +383,12 @@ class array_subbyte CUTE_HOST_DEVICE constexpr reference back() { - return reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); + return at(N-1); } CUTE_HOST_DEVICE constexpr const_reference back() const { - return const_reference(storage + kStorageElements - 1, kElementsPerStoredItem - 1); + return at(N-1); } CUTE_HOST_DEVICE constexpr @@ -462,12 +402,12 @@ class array_subbyte } CUTE_HOST_DEVICE constexpr - Storage* raw_data() { + storage_type* raw_data() { return storage; } CUTE_HOST_DEVICE constexpr - Storage const* raw_data() const { + storage_type const* raw_data() const { return storage; } @@ -488,12 +428,12 @@ class array_subbyte CUTE_HOST_DEVICE constexpr iterator end() { - return iterator(storage + N / kElementsPerStoredItem, N % kElementsPerStoredItem); + return iterator(storage + N / ElementsPerStoredItem, N % ElementsPerStoredItem); } CUTE_HOST_DEVICE constexpr const_iterator end() const { - return const_iterator(storage + N / kElementsPerStoredItem, N % kElementsPerStoredItem); + return const_iterator(storage + N / ElementsPerStoredItem, N % ElementsPerStoredItem); } CUTE_HOST_DEVICE constexpr @@ -525,8 +465,6 @@ void fill(array_subbyte& a, T const& value) a.fill(value); } -//////////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace cute // @@ -573,7 +511,7 @@ namespace CUTE_STL_NAMESPACE template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template @@ -584,7 +522,7 @@ struct tuple_element> template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template @@ -609,7 +547,7 @@ struct tuple_element; template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template @@ -620,7 +558,7 @@ struct tuple_element> template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template diff --git a/include/cute/container/tuple.hpp b/include/cute/container/tuple.hpp index ab6d37dca3..3455a41620 100644 --- a/include/cute/container/tuple.hpp +++ b/include/cute/container/tuple.hpp @@ -642,7 +642,7 @@ namespace CUTE_STL_NAMESPACE template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template @@ -652,7 +652,7 @@ struct tuple_element> template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template @@ -680,7 +680,7 @@ struct tuple_element; template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template @@ -690,7 +690,7 @@ struct tuple_element> template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template diff --git a/include/cute/container/type_list.hpp b/include/cute/container/type_list.hpp index 4c6ddc09c9..41ff1e7d1e 100644 --- a/include/cute/container/type_list.hpp +++ b/include/cute/container/type_list.hpp @@ -80,7 +80,7 @@ namespace CUTE_STL_NAMESPACE template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template @@ -90,7 +90,7 @@ struct tuple_element> template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template @@ -114,7 +114,7 @@ struct tuple_element; template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template @@ -124,7 +124,7 @@ struct tuple_element> template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template diff --git a/include/cute/int_tuple.hpp b/include/cute/int_tuple.hpp index b73e2ec7a7..7875ac1581 100644 --- a/include/cute/int_tuple.hpp +++ b/include/cute/int_tuple.hpp @@ -52,25 +52,38 @@ make_int_tuple(Ts const&... t) return {t...}; } -/** if rank(int) == 1, then get<0>(int) should work too - */ -template >::value)> -CUTE_HOST_DEVICE constexpr -decltype(auto) +// CuTe does not treat integers as tuples. +// For example, is_tuple is false, and tuple_size doesn't compile. +// Nevertheless, CuTe defines rank(Integral) as 1 +// (where "Integral" is a shorthand for either run-time integers +// or CuTe's compile-time integer constants), +// so therefore get<0>(Integral) just returns its input. +template >::value)> +CUTE_HOST_DEVICE constexpr decltype(auto) get(T&& t) noexcept { static_assert(I == 0, "Index out of range"); return static_cast(t); } -/** Custom recursive get for anything that implements get(.) - */ +// Custom recursive get for anything that implements get(.) (for a single integer I). template -CUTE_HOST_DEVICE constexpr -decltype(auto) +CUTE_HOST_DEVICE constexpr decltype(auto) get(Tuple&& t) noexcept { - return get(get(static_cast(t))); + using get_I0_result_t = cute::remove_cvref_t(static_cast(t)))>; + if constexpr (cute::is_integral::value) { + // Help MSVC deduce that the inner get(...) call is not a "local variable or temporary." + // The above if constexpr test repeats the constraint on the above get(T&&) overload. + // get<0, 0, ..., 0>(t) for cute::integral (either one of the built-in integer types like int, + // or one of CuTe's compile-time constant types) t, and for one or more zeros, just returns t. + static_assert(I1 == 0, "Index I1 is out of range"); + static_assert(((Is == 0) && ...), "At least one index in Is is out of range"); + return get(static_cast(t)); + } + else { + return get(get(static_cast(t))); + } } // @@ -173,6 +186,26 @@ min(T0 const& t0, Ts const&... ts) CUTE_GCC_UNREACHABLE; } +// +// gcd +// + +template +CUTE_HOST_DEVICE constexpr +auto +gcd(T0 const& t0, Ts const&... ts) +{ + if constexpr (is_tuple::value) { + return cute::gcd(cute::apply(t0, [](auto const&... a){ return cute::gcd(a...); }), ts...); + } else if constexpr (sizeof...(Ts) == 0) { + return t0; + } else { + return cute::gcd(t0, cute::gcd(ts...)); + } + + CUTE_GCC_UNREACHABLE; +} + // // depth // @@ -219,12 +252,22 @@ product(IntTuple const& a) CUTE_GCC_UNREACHABLE; } +// Return a rank(t) tuple @a result such that get(@a result) = product(get(@a t)) template CUTE_HOST_DEVICE constexpr auto product_each(Tuple const& t) { - return transform(t, [](auto const& x) { return product(x); }); + return transform(wrap(t), [](auto const& x) { return product(x); }); +} + +// Take the product of Tuple at the leaves of TupleG +template +CUTE_HOST_DEVICE constexpr +auto +product_like(Tuple const& tuple, TupleG const& guide) +{ + return transform_leaf(guide, tuple, [](auto const& g, auto const& t) { return product(t); }); } // Return the product of elements in a mode @@ -347,6 +390,25 @@ shape_div(constant const&, constant const&) return {}; } +/** Minimum for Shapes + */ +template +CUTE_HOST_DEVICE constexpr +auto +shape_min(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value || is_tuple::value) { + static_assert(dependent_false, "Not implemented."); + } else + if constexpr (is_constant<1, IntTupleA>::value || is_constant<1, IntTupleB>::value) { + return Int<1>{}; // _1 is less than all other shapes, preserve static + } else { + return cute::min(a, b); + } + + CUTE_GCC_UNREACHABLE; +} + /** Return a tuple the same profile as A scaled by corresponding elements in B */ template @@ -371,7 +433,7 @@ auto congruent(IntTupleA const& a, IntTupleB const& b) { return bool_constant::value>{}; + decltype(repeat_like(shape(b),_0{}))>::value>{}; } template diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index cdbbb5ace1..5b81cfd833 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -367,7 +367,7 @@ CUTE_HOST_DEVICE constexpr auto make_ordered_layout(Shape const& shape, Order const& order) { - static_assert(is_static::value && is_static::value); + static_assert(is_static::value); return make_layout(shape, compact_order(shape, order)); } @@ -464,6 +464,19 @@ take(Layout const& layout) return make_layout(take(layout.shape()), take(layout.stride())); } +// +// Select layout modes according to an index sequence. +// + +template +CUTE_HOST_DEVICE constexpr +auto +select(Layout const& layout) +{ + return make_layout(select(layout.shape()), + select(layout.stride())); +} + template CUTE_HOST_DEVICE constexpr auto @@ -472,6 +485,15 @@ flatten(Layout const& layout) return make_layout(flatten(layout.shape()), flatten(layout.stride())); } +template +CUTE_HOST_DEVICE constexpr +auto +unflatten(Layout const& layout, TargetProfile const& target_profile) +{ + return make_layout(unflatten(layout.shape(), target_profile), + unflatten(layout.stride(), target_profile)); +} + // // Utilities // @@ -552,18 +574,33 @@ depth(Layout const& layout) return depth(shape(layout)); } -// Return the codomain size of a mode -// @return M smallest integer such that @a sub_layout(c) < M for all c < size(@a sub_layout) +// Return the codomain shape of a mode +// @post size(coshape(@a a)) == cosize(@a a) +// @return C Coordinate with smallest elements such that that +// @a elem_less(sub_layout(c), C) for all c < size(@a sub_layout) // where sub_layout = get(layout). template CUTE_HOST_DEVICE constexpr auto -cosize(Layout const& layout) +coshape(Layout const& layout) { // Protect against negative strides auto abs_sub_layout = make_layout(shape(layout), transform_leaf(stride(layout), abs_fn{})); - return abs_sub_layout(size(abs_sub_layout) - Int<1>{}) + Int<1>{}; + auto co_coord = as_arithmetic_tuple(abs_sub_layout(size(abs_sub_layout) - Int<1>{})); + return co_coord + repeat_like(co_coord, Int<1>{}); +} + +// Return the codomain size of a mode +// @return M smallest integer such that +// @a sub_layout(c) < M for all c < size(@a sub_layout) +// where sub_layout = get(layout). +template +CUTE_HOST_DEVICE constexpr +auto +cosize(Layout const& layout) +{ + return size(coshape(layout)); } template @@ -622,6 +659,16 @@ dice(Coord const& c, Layout const& layout) dice(c, layout.stride())); } +// Compute a pointer offset and (potentially modified) layout from a coordinate +// This exists so it can be overloaded for ComposedLayout +template +CUTE_HOST_DEVICE constexpr +auto +domain_offset(Coord const& coord, Layout const& layout) +{ + return cute::make_tuple(layout, layout(coord)); +} + // // Transform the modes of a layout // @@ -794,6 +841,16 @@ append(Layout const& layout, append(layout.stride(), x.stride())); } +template +CUTE_HOST_DEVICE constexpr +auto +append(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(append(layout.shape(), x.shape()), + append(layout.stride(), x.stride())); +} + template CUTE_HOST_DEVICE constexpr auto @@ -804,6 +861,16 @@ prepend(Layout const& layout, prepend(layout.stride(), x.stride())); } +template +CUTE_HOST_DEVICE constexpr +auto +prepend(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(prepend(layout.shape(), x.shape()), + prepend(layout.stride(), x.stride())); +} + template CUTE_HOST_DEVICE constexpr auto @@ -836,16 +903,16 @@ template CUTE_HOST_DEVICE constexpr auto -composition(Layout const& lhs, - RShape const& rhs_shape, RStride const& rhs_stride) +composition_impl(Layout const& lhs, + RShape const& rhs_shape, RStride const& rhs_stride) { if constexpr (is_tuple::value) { // Apply the right-distributivity of Layout composition - return transform_layout(rhs_shape, rhs_stride, [&](auto const& s, auto const& d) { return composition(lhs, s, d); }); + return transform_layout(rhs_shape, rhs_stride, [&](auto const& s, auto const& d) { return composition_impl(lhs, s, d); }); } else if constexpr (is_scaled_basis::value) { // Special case for a ScaledBasis stride - return composition(get(lhs), rhs_shape, rhs_stride.value()); + return composition_impl(get(lhs), rhs_shape, rhs_stride.value()); } else if constexpr (is_integral::value) { // Integral Rstride (and RShape) @@ -871,7 +938,7 @@ composition(Layout const& lhs, // Mod out the rhs_shape from the lhs.shape() auto const [result_shape_1, rest_shape] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_shape), [] (auto const& init, auto const& si) { - return cute::make_tuple(append(get<0>(init), cute::min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); + return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); }); // Jump into coalesce and append (rest_shape, get(lhs.stride()) @@ -892,9 +959,9 @@ composition(Layout const& lhs, auto result_stride_1 = elem_scale(result_stride_0, shape_div(result_shape_0, result_shape_1)); // Mod out the rhs_shape from the lhs.shape() - auto const [result_shape_2, rest_shape] = fold(result_shape_1, cute::make_tuple(cute::make_tuple(), rhs_shape), + auto const [result_shape_2, rest_shape] = fold(result_shape_1, cute::make_tuple(cute::make_tuple(), rhs_shape), [] (auto const& init, auto const& si) { - return cute::make_tuple(append(get<0>(init), cute::min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); + return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); }); // Jump into coalesce and append (rest_shape, rest_stride * get(lhs.stride()) @@ -914,8 +981,8 @@ auto composition(Layout const& lhs, Layout const& rhs) { - //return detail::composition(flatten(lhs), rhs.shape(), rhs.stride()); - return detail::composition(lhs, rhs.shape(), rhs.stride()); + //return detail::composition_impl(flatten(lhs), rhs.shape(), rhs.stride()); + return detail::composition_impl(lhs, rhs.shape(), rhs.stride()); } template @@ -968,10 +1035,7 @@ complement(Shape const& shape, Stride const& stride, CoSizeHi const& cosize_hi) // Should just be a sort and a fold... // Then we could even handle dynamic strides (but they would destroy all static strides) auto result = fold(make_seq{}, - cute::make_tuple(shape, - stride, - cute::make_tuple(), - cute::make_tuple(Int<1>{})), + cute::make_tuple(shape, stride, cute::make_tuple(), cute::make_tuple(Int<1>{})), [](auto const& init, auto i) { auto curr_stride = cute::min(get<1>(init)); @@ -979,9 +1043,9 @@ complement(Shape const& shape, Stride const& stride, CoSizeHi const& cosize_hi) auto curr_shape = get(get<0>(init)); return cute::make_tuple(remove(get<0>(init)), // Remove the curr shape - remove(get<1>(init)), // Remove the curr stride - append(get<2>(init), curr_stride / get<3,i>(init)), // new shape = curr_stride / last_stride - append(get<3>(init), curr_shape * curr_stride)); // new stride = curr_shape * curr_stride + remove(get<1>(init)), // Remove the curr stride + append(get<2>(init), curr_stride / get<3,i>(init)), // new shape = curr_stride / last_stride + append(get<3>(init), curr_shape * curr_stride)); // new stride = curr_shape * curr_stride }); // Append the last shape mode @@ -1025,22 +1089,24 @@ complement(Layout const& layout) namespace detail { -template +template CUTE_HOST_DEVICE constexpr auto inverse_seq(Shape const& shape, Stride const& stride, seq) { - if constexpr (I == decltype(rank(stride))::value) { + auto next_I = find_if(stride, [](auto a) { return is_constant{}; }); + + if constexpr (next_I == decltype(rank(stride))::value) { return seq{}; } else { - //auto next_stride = get(shape) * get(stride); - using next_stride = decltype(get(shape) * get(stride)); // NOTE: WAR for g++-7 + // auto next_stride = get(shape) * get(stride); + // NOTE: Needed for g++-7 + using next_stride = decltype(get(shape) * get(stride)); if constexpr (is_static::value) { - auto next_idx = find_if(stride, [](auto a) { return is_constant{}; }); - return inverse_seq(shape, stride, seq{}); + return inverse_seq(shape, stride, seq{}); } else { - return seq{}; + return seq{}; } } @@ -1066,11 +1132,10 @@ right_inverse(Layout const& layout) auto flat_layout = coalesce(layout); auto astride = transform_leaf(flat_layout.stride(), abs_fn{}); - // Find Int<1>{}, the starting idx, and follow the strides to gen inverse_seq - auto next_I = find_if(astride, [](auto a) { return is_constant<1, decltype(a)>{}; }); - [[maybe_unused]] auto iseq = detail::inverse_seq(flat_layout.shape(), astride, seq<>{}); + // Find Int<1>{}, the starting stride, and follow the strides to gen inverse_seq + [[maybe_unused]] auto iseq = detail::inverse_seq<1>(flat_layout.shape(), astride, seq<>{}); - if constexpr (tuple_size::value == 0) { + if constexpr (iseq.size() == 0) { return Layout<_1,_0>{}; // Empty case, nothing found } else { // Generate the corresponding new strides and construct @@ -1150,8 +1215,6 @@ max_common_layout(Layout const& a, // (i.e. are large and multiples of the vector) return Layout<_1,_0>{}; } - - CUTE_GCC_UNREACHABLE; } /* Return Int such that N is the maximum number of contiguous elements @@ -1168,7 +1231,78 @@ auto max_common_vector(Layout const& a, Layout const& b) { - return size(max_common_layout(a, b)); + if constexpr (is_static::value && is_static::value && + is_static::value && is_static::value) + { + Layout common = coalesce(composition(a, right_inverse(b))); + + if constexpr (is_constant<1, decltype(stride<0>(common))>::value) { + // Truncate to the size of the contiguous vector (static stride-1 mode) + return shape<0>(common); + } else { + return Int<1>{}; + } + } else { + // CASE: One of the layouts is dynamic, can't prove alignment+vectorization is valid + // NOTE: Could weaken if we assume dynamic shapes/strides obey alignment requirements + // (i.e. are large and multiples of the vector) + return Int<1>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Kernel (Nullspace) of a Layout +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +nullspace_seq(Stride const& stride, seq) +{ + if constexpr (NextI == rank_v) { + return seq{}; + } else + if constexpr (is_constant<0, decltype(get(stride))>::value) { + return detail::nullspace_seq(stride, seq{}); + } else { + return detail::nullspace_seq(stride, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// +// Build the nullspace of a layout +// @result A layout @a result such that +// size(@a result) == size(@a layout) / size(filter(@a layout)) +// @a layout(@a result(i)) == 0 for all i < size(@a result) +// + +template +CUTE_HOST_DEVICE constexpr +auto +nullspace(Layout const& layout) +{ + auto flat_layout = flatten(layout); + + auto iseq = detail::nullspace_seq<0>(flat_layout.stride(), seq<>{}); + + if constexpr (iseq.size() == 0) { + return Layout<_1,_0>{}; // Empty case, nothing found + } else { + // Generate the corresponding new strides and construct + auto rstride = compact_col_major(flat_layout.shape()); + return make_layout(unwrap(transform(iseq, [&](auto i) { return shape(flat_layout); })), + unwrap(transform(iseq, [&](auto i) { return get(rstride); }))); + } + + CUTE_GCC_UNREACHABLE; } // diff --git a/include/cute/layout_composed.hpp b/include/cute/layout_composed.hpp new file mode 100644 index 0000000000..7b3b6f4f68 --- /dev/null +++ b/include/cute/layout_composed.hpp @@ -0,0 +1,609 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +/* This implements a ComposedLayout of the form + * LayoutA o Offset o LayoutB + * and is useful in cases where composition() does not or cannot apply to LayoutA and LayoutB. + * For example, then the "divisibility condition" in shape_div is violated in composition(LayoutA, LayoutB). + * + * This ComposedLayout provides similar functionality to Layout including tiling, partitioning, + * coordinate-to-index mapping and layout manipulations, but is not considered a "normal" layout. + * For example, this layout provides shape() and size() functions, but does not provide stride() functions. + * Mostly, the similar functionality is accomplished by applying each operation to LayoutB only + * as LayoutB defines the domain. + */ + +namespace cute +{ + +// A Layout of non-trivially composable functions: F o I o L +template +struct ComposedLayout : private cute::tuple // EBO for static layouts +{ + CUTE_HOST_DEVICE constexpr + ComposedLayout(LayoutA const& layoutA = {}, + Offset const& offset = {}, + LayoutB const& layoutB = {}) + : cute::tuple(layoutA, offset, layoutB) + {} + + // + // Accessors + // + + static constexpr int rank = LayoutB::rank; + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout_a() const { + return get<0>(static_cast const&>(*this)); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + offset() const { + return get<1>(static_cast const&>(*this)); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout_b() const { + return get<2>(static_cast const&>(*this)); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() const { + return *this; + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() const { + return layout_b().shape(); + } + + // Doesn't really make sense to ask for the strides of this "layout" + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() const = delete; + + // + // Mappings + // + + // Map a logical coordinate to a linear index (Coord has no Underscore slice operators) + // OR + // Slice the layout and return the sublayout (Coord has an Underscore slice op) + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord const& coord) const { + if constexpr (has_underscore::value) { + return slice(coord, *this); + } else { + return layout_a()(offset() + layout_b()(coord)); // (A o O o B)(c) + } + + CUTE_GCC_UNREACHABLE; + } + + // Convenience function for multi-dimensional coordinates + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { + return operator()(make_coord(c0,c1,cs...)); + } + + // + // Compose + // + + template + CUTE_HOST_DEVICE constexpr + auto + compose(OtherLayout const& other) const { + return composition(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) const { + return composition(*this, make_tile(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(OtherShape const& shape) const { + return composition(*this, make_layout(shape)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(Shapes const&... shapes) const { + return composition(*this, make_layout(make_shape(shapes...))); + } + + // + // Tile + // + + template + CUTE_HOST_DEVICE constexpr + auto + tile(OtherLayout const& other) const { + return tiled_divide(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... layouts) const { + return tiled_divide(*this, make_tile(layouts...)); + } +}; + +template +struct is_layout> : true_type {}; + +template +struct is_composed_layout : false_type {}; +template +struct is_composed_layout> : true_type {}; + +// +// Constructors +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_composed_layout(LayoutA const& layoutA, + Offset const& offset, + LayoutB const& layoutB) +{ + return ComposedLayout{layoutA, offset, layoutB}; +} + +// +// Utilities +// + +// Return the layout of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +layout(ComposedLayout const& clayout) +{ + return composition(clayout.layout_a(), clayout.offset(), layout(clayout.layout_b())); +} + +// Return the shape of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +shape(ComposedLayout const& layout) +{ + return shape(layout.layout_b()); +} + +// Doesn't make sense to directly ask for the strides of this "layout" +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +stride(ComposedLayout const& layout) = delete; + +// Return the number of elements in a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +size(ComposedLayout const& layout) +{ + return size(layout.layout_b()); +} + +// Return the number of modes +template +CUTE_HOST_DEVICE constexpr +auto +rank(ComposedLayout const& layout) +{ + return rank(layout.layout_b()); +} + +// Return the depth of the layout +template +CUTE_HOST_DEVICE constexpr +auto +depth(ComposedLayout const& layout) +{ + return depth(layout.layout_b()); +} + +// Return the codomain size of a mode +template +CUTE_HOST_DEVICE constexpr +auto +cosize(ComposedLayout const& layout) +{ + return cosize(layout.layout_b()); +} + +// +// Operations to manipulate Layouts like a tuple of pairs +// + +template +CUTE_HOST_DEVICE constexpr +auto +get(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), get(a.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +take(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), take(a.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flatten(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), flatten(a.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +append(ComposedLayout const& a, X const& x) +{ + return composition(a.layout_a(), a.offset(), append(a.layout_b(), x)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +group(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), group(a.layout_b())); +} + +// +// Slice a ComposedLayout +// + +template +CUTE_HOST_DEVICE constexpr +auto +slice_and_offset(Coord const& coord, ComposedLayout const& layout) +{ + auto [slice, offset] = slice_and_offset(coord, layout.layout_b()); + return cute::make_tuple(ComposedLayout{layout.layout_a(), layout.offset() + offset, slice}, Int<0>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +slice(Coord const& coord, ComposedLayout const& layout) +{ + return get<0>(slice_and_offset(coord, layout)); +} + +// Compute a pointer offset and (potentially modified) layout from a coordinate +// For composed layout tensors the offset is accumulated in the layout itself while pointer is not updated +template +CUTE_HOST_DEVICE constexpr +auto +domain_offset(Coord const& coord, ComposedLayout const& layout) +{ + return cute::make_tuple(ComposedLayout{layout.layout_a(), layout.offset() + layout.layout_b()(coord), layout.layout_b()}, Int<0>{}); +} + +// +// composition +// + +template +CUTE_HOST_DEVICE constexpr +auto +composition(LayoutA const& layoutA, + Offset const& offset, + LayoutB const& layoutB) +{ + return ComposedLayout{layoutA, offset, layoutB}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(ComposedLayout const& a, + LayoutOrTile const& b) +{ + return composition(a.layout_a(), a.offset(), composition(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& a, + ComposedLayout const& b) +{ + CUTE_STATIC_ASSERT_V(b.offset() == Int<0>{}, "Require offset == 0."); + + return composition(composition(a, b.layout_a()), b.layout_b()); +} + +// +// complement +// + +template +CUTE_HOST_DEVICE constexpr +auto +complement(ComposedLayout const& layout, CoSizeHi const& cosize_hi) +{ + return complement(layout.layout_b(), cosize_hi); +} + +template +CUTE_HOST_DEVICE constexpr +auto +complement(ComposedLayout const& layout) +{ + return complement(layout, cosize(layout)); +} + +// +// inverse +// + +template +CUTE_HOST_DEVICE constexpr +auto +right_inverse(ComposedLayout const& layout) +{ + return composition(right_inverse(layout.layout_b()), right_inverse(layout.offset()), right_inverse(layout.layout_a())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +left_inverse(ComposedLayout const& layout) +{ + return composition(left_inverse(layout.layout_b()), left_inverse(layout.offset()), left_inverse(layout.layout_a())); +} + +// +// Other operations +// + +template +CUTE_HOST_DEVICE constexpr +auto +zip(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), zip(a.layout_b())); +} + +// Partitions + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(ComposedLayout const& a, + Tile const& b) +{ + return composition(a.layout_a(), a.offset(), logical_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tile_unzip(ComposedLayout const& a, + Tile const& b) +{ + return composition(a.layout_a(), a.offset(), tile_unzip(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tiled_divide(ComposedLayout const& a, + Tile const& b) +{ + return composition(a.layout_a(), a.offset(), tiled_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_divide(ComposedLayout const& a, + Tile const& b) +{ + return composition(a.layout_a(), a.offset(), zipped_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(ComposedLayout const& a, + Tile const& b) +{ + return composition(a.layout_a(), a.offset(), logical_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tiled_product(ComposedLayout const& a, + Tile const& b) +{ + return composition(a.layout_a(), a.offset(), tiled_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +blocked_product(ComposedLayout const& a, + Tile const& b) +{ + return composition(a.layout_a(), a.offset(), blocked_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +raked_product(ComposedLayout const& a, + Tile const& b) +{ + return composition(a.layout_a(), a.offset(), raked_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tile_to_shape(ComposedLayout const& layout, + Shape const& trg_shape, + ModeOrder const& ord_shape = {}) +{ + return composition(layout.layout_a(), layout.offset(), tile_to_shape(layout.layout_b(), trg_shape, ord_shape)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter(ComposedLayout const& layout, Shape const& trg_profile) +{ + return composition(layout.layout_a(), layout.offset(), filter(layout.layout_b(), trg_profile)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(ComposedLayout const& layout) +{ + return composition(layout.layout_a(), layout.offset(), coalesce(layout.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(ComposedLayout const& layout, Shape const& trg_profile) +{ + return composition(layout.layout_a(), layout.offset(), coalesce(layout.layout_b(), trg_profile)); +} + +// +// Upcast and Downcast +// + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout const& layout) +{ + return composition(upcast(layout.layout_a()), upcast(layout.offset()), upcast(layout.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(ComposedLayout const& layout) +{ + return composition(downcast(layout.layout_a()), downcast(layout.offset()), downcast(layout.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast(ComposedLayout const& layout) +{ + if constexpr (sizeof(NewType) == sizeof(OldType)) { + return layout; + } else if constexpr (sizeof(NewType) > sizeof(OldType)) { + static_assert(sizeof(NewType) % sizeof(OldType) == 0, "NewType must be a multiple of OldType"); + return upcast(layout); + } else if constexpr (sizeof(NewType) < sizeof(OldType)) { + static_assert(sizeof(OldType) % sizeof(NewType) == 0, "NewType must be a divisor of OldType"); + return downcast(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(ComposedLayout const& layout) +{ + print(layout.layout_a()); print(" o "); print(layout.offset()); print(" o "); print(layout.layout_b()); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, ComposedLayout const& layout) +{ + return os << layout.layout_a() << " o " << layout.offset() << " o " << layout.layout_b(); +} +#endif + +} // end namespace cute diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp index a7ce47f12c..c2c73be7d8 100644 --- a/include/cute/numeric/arithmetic_tuple.hpp +++ b/include/cute/numeric/arithmetic_tuple.hpp @@ -77,6 +77,20 @@ as_arithmetic_tuple(tuple const& t) { return ArithmeticTuple(t); } +template ::value)> +CUTE_HOST_DEVICE constexpr +T const& +as_arithmetic_tuple(T const& t) { + return t; +} + +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ArithmeticTuple const& t) { + return t; +} + // // Numeric operators // @@ -110,18 +124,26 @@ operator+(tuple const& t, ArithmeticTuple const& u) { // Special cases // -template +template CUTE_HOST_DEVICE constexpr auto -operator+(constant, ArithmeticTuple const& u) { - return u; +operator+(C, ArithmeticTuple const& u) { + if constexpr (t == 0) { + return u; + } else { + static_assert(t == 0, "Artihmetic tuple op+ error!"); + } } -template +template CUTE_HOST_DEVICE constexpr auto -operator+(ArithmeticTuple const& t, constant) { - return t; +operator+(ArithmeticTuple const& t, C) { + if constexpr (u == 0) { + return t; + } else { + static_assert(u == 0, "Artihmetic tuple op+ error!"); + } } // @@ -159,11 +181,9 @@ CUTE_HOST_DEVICE void print(ArithmeticTupleIterator const& iter) { // // ArithmeticTuple "basis" elements -// - -// Abstract value: -// A ScaledBasis is a (at least) rank-N0 ArithmeticTuple: +// A ScaledBasis is a (at least) rank-N+1 ArithmeticTuple: // (_0,_0,...,T,_0,...) +// with value T in the Nth mode template struct ScaledBasis : private tuple @@ -188,16 +208,30 @@ struct is_scaled_basis> : true_type {}; template struct is_integral> : true_type {}; -template +// Get the scalar T out of a ScaledBasis +template CUTE_HOST_DEVICE constexpr auto -basis_value(T const& e) { - return e; +basis_value(SB const& e) +{ + if constexpr (is_scaled_basis::value) { + return basis_value(e.value()); + } else { + return e; + } + CUTE_GCC_UNREACHABLE; } -template +// Apply the N... pack to another Tuple +template CUTE_HOST_DEVICE constexpr auto -basis_value(ScaledBasis const& e) { - return basis_value(e.value()); +basis_get(SB const& e, Tuple const& t) +{ + if constexpr (is_scaled_basis::value) { + return basis_get(e.value(), get(t)); + } else { + return t; + } + CUTE_GCC_UNREACHABLE; } namespace detail { @@ -217,6 +251,14 @@ struct Basis { } // end namespace detail +// Shortcut for writing ScaledBasis, N0>, N1>, ...> +// E<> := _1 +// E<0> := (_1,_0,_0,...) +// E<1> := (_0,_1,_0,...) +// E<0,0> := ((_1,_0,_0,...),_0,_0,...) +// E<0,1> := ((_0,_1,_0,...),_0,_0,...) +// E<1,0> := (_0,(_1,_0,_0,...),_0,...) +// E<1,1> := (_0,(_0,_1,_0,...),_0,...) template using E = typename detail::Basis::type; @@ -248,6 +290,15 @@ as_arithmetic_tuple(ScaledBasis const& t) { return detail::as_arithmetic_tuple(t.value(), make_seq{}, make_seq{}); } +// Turn a ScaledBases into a rank-N ArithmeticTuple +// with N prefix 0s: (_0,_0,...N...,_0,T) +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ScaledBasis const& t) { + return as_arithmetic_tuple(t); +} + // Turn an ArithmeticTuple into a rank-M ArithmeticTuple // with postfix 0s: (t0,t1,t2,...,_0,...,_0,_0) template @@ -258,7 +309,24 @@ as_arithmetic_tuple(ArithmeticTuple const& t) { return detail::as_arithmetic_tuple(t, make_seq{}, make_seq{}); } -// Return... +template +CUTE_HOST_DEVICE constexpr +auto +safe_div(ScaledBasis const& b, U const& u) +{ + auto t = safe_div(b.value(), u); + return ScaledBasis{t}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +shape_div(ScaledBasis const& b, U const& u) +{ + auto t = shape_div(b.value(), u); + return ScaledBasis{t}; +} + template CUTE_HOST_DEVICE constexpr auto @@ -266,12 +334,21 @@ make_basis_like(Shape const& shape) { if constexpr (is_integral::value) { return Int<1>{}; - } else { + } + else { // Generate bases for each rank of shape - return transform(tuple_seq{}, [&](auto I) { - // Generate bases for each rank of shape_i and add an i on front - constexpr int i = decltype(I)::value; // NOTE: nvcc workaround - return transform_leaf(make_basis_like(get(shape)), [&](auto e) { return ScaledBasis{}; }); + return transform(tuple_seq{}, shape, [](auto I, auto si) { + // Generate bases for each rank of si and add an i on front + using I_type = decltype(I); + return transform_leaf(make_basis_like(si), [](auto e) { + // MSVC has trouble capturing variables as constexpr, + // so that they can be used as template arguments. + // This is exactly what the code needs to do with i, unfortunately. + // The work-around is to define i inside the inner lambda, + // by using just the type from the enclosing scope. + constexpr int i = I_type::value; + return ScaledBasis{}; + }); }); } @@ -279,25 +356,34 @@ make_basis_like(Shape const& shape) } // Equality -template +template CUTE_HOST_DEVICE constexpr auto -operator==(ScaledBasis, Int) { - return false_type{}; +operator==(ScaledBasis const& t, ScaledBasis const& u) { + return bool_constant{} && t.value() == u.value(); } -template +// Not equal to anything else +template CUTE_HOST_DEVICE constexpr -auto -operator==(Int, ScaledBasis) { - return false_type{}; +false_type +operator==(ScaledBasis const&, U const&) { + return {}; } -template +template +CUTE_HOST_DEVICE constexpr +false_type +operator==(T const&, ScaledBasis const&) { + return {}; +} + +// Abs +template CUTE_HOST_DEVICE constexpr auto -operator==(ScaledBasis const& t, ScaledBasis const& u) { - return bool_constant{} && t.value() == u.value(); +abs(ScaledBasis const& e) { + return ScaledBasis{abs(e.value())}; } // Multiplication @@ -306,7 +392,8 @@ template const& e) { - return ScaledBasis{a*e.value()}; + auto r = a * e.value(); + return ScaledBasis{r}; } template const& e, B const& b) { - return ScaledBasis{e.value()*b}; + auto r = e.value() * b; + return ScaledBasis{r}; } // Addition @@ -334,6 +422,22 @@ operator+(ArithmeticTuple const& t, ScaledBasis const& u) { return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); } +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ScaledBasis const& t, tuple const& u) { + constexpr int R = cute::max(N+1, int(sizeof...(U))); + return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(tuple const& t, ScaledBasis const& u) { + constexpr int R = cute::max(int(sizeof...(T)), M+1); + return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); +} + template CUTE_HOST_DEVICE constexpr auto @@ -342,18 +446,26 @@ operator+(ScaledBasis const& t, ScaledBasis const& u) { return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); } -template +template CUTE_HOST_DEVICE constexpr auto -operator+(constant, ScaledBasis const& u) { - return u; +operator+(C, ScaledBasis const& u) { + if constexpr (t == 0) { + return u; + } else { + static_assert(t == 0, "ScaledBasis op+ error!"); + } } -template +template CUTE_HOST_DEVICE constexpr auto -operator+(ScaledBasis const& t, constant) { - return t; +operator+(ScaledBasis const& t, C) { + if constexpr (u == 0) { + return t; + } else { + static_assert(u == 0, "ScaledBasis op+ error!"); + } } // @@ -380,7 +492,7 @@ namespace CUTE_STL_NAMESPACE template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template @@ -390,7 +502,7 @@ struct tuple_element> template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template @@ -414,7 +526,7 @@ struct tuple_element; template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template @@ -424,7 +536,7 @@ struct tuple_element> template struct tuple_size> - : cute::integral_constant + : CUTE_STL_NAMESPACE::integral_constant {}; template diff --git a/include/cute/numeric/int.hpp b/include/cute/numeric/int.hpp index e2b2988491..9be920d1b0 100644 --- a/include/cute/numeric/int.hpp +++ b/include/cute/numeric/int.hpp @@ -121,14 +121,25 @@ template struct sizeof_bits { static constexpr size_t value = sizeof(T) * 8; }; + +template +struct sizeof_bits: sizeof_bits {}; + +template <> +struct sizeof_bits { + static constexpr size_t value = 0; +}; + template <> struct sizeof_bits { static constexpr size_t value = 1; }; + template struct sizeof_bits> { static constexpr size_t value = Bits; }; + template static constexpr int sizeof_bits_v = sizeof_bits::value; diff --git a/include/cute/numeric/integral_constant.hpp b/include/cute/numeric/integral_constant.hpp index a0b5b07519..bb165111f0 100644 --- a/include/cute/numeric/integral_constant.hpp +++ b/include/cute/numeric/integral_constant.hpp @@ -38,20 +38,24 @@ namespace cute { -template -struct constant : CUTE_STL_NAMESPACE::integral_constant { - static constexpr T value = v; - using value_type = T; - using type = constant; +// Short name for fast compilation +template +struct C { + using type = C; + static constexpr auto value = v; + using value_type = decltype(v); CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; } CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; } }; template -using integral_constant = constant; +using constant = C; + +template +using integral_constant = C; template -using bool_constant = constant; +using bool_constant = C; using true_type = bool_constant; using false_type = bool_constant; @@ -60,40 +64,43 @@ using false_type = bool_constant; // Traits // -// Use std::is_integral to match built-in integral types (int, int64_t, unsigned, etc) +// Use cute::is_std_integral to match built-in integral types (int, int64_t, unsigned, etc) // Use cute::is_integral to match both built-in integral types AND constant template -struct is_integral : bool_constant::value> {}; -template -struct is_integral> : true_type {}; +struct is_integral : bool_constant::value> {}; +template +struct is_integral> : true_type {}; // is_static detects if an (abstract) value is defined completely by it's type (no members) template struct is_static : bool_constant::value> {}; +template +constexpr bool is_static_v = is_static::value; + // is_constant detects if a type is a constant and if v is equal to a value template struct is_constant : false_type {}; -template -struct is_constant > : bool_constant {}; -template -struct is_constant const > : bool_constant {}; -template -struct is_constant const&> : bool_constant {}; -template -struct is_constant &> : bool_constant {}; -template -struct is_constant &&> : bool_constant {}; +template +struct is_constant > : bool_constant {}; +template +struct is_constant const > : bool_constant {}; +template +struct is_constant const&> : bool_constant {}; +template +struct is_constant &> : bool_constant {}; +template +struct is_constant &&> : bool_constant {}; // // Specializations // template -using Int = constant; +using Int = C; using _m32 = Int<-32>; using _m24 = Int<-24>; @@ -146,25 +153,21 @@ using _524288 = Int<524288>; /***************/ #define CUTE_LEFT_UNARY_OP(OP) \ - template \ + template \ CUTE_HOST_DEVICE constexpr \ - constant \ - operator OP (constant) { \ + C<(OP t)> operator OP (C) { \ return {}; \ } #define CUTE_RIGHT_UNARY_OP(OP) \ - template \ + template \ CUTE_HOST_DEVICE constexpr \ - constant \ - operator OP (constant) { \ + C<(t OP)> operator OP (C) { \ return {}; \ } - #define CUTE_BINARY_OP(OP) \ - template \ + template \ CUTE_HOST_DEVICE constexpr \ - constant \ - operator OP (constant, constant) { \ + C<(t OP u)> operator OP (C, C) { \ return {}; \ } @@ -203,99 +206,91 @@ CUTE_BINARY_OP(<=); // Mixed static-dynamic special cases // -template ::value)> -CUTE_HOST_DEVICE constexpr -constant -operator*(constant, U) { - return {}; -} - -template ::value)> +template ::value && t == 0)> CUTE_HOST_DEVICE constexpr -constant -operator*(U, constant) { +C<0> +operator*(C, U) { return {}; } -template ::value)> +template ::value && t == 0)> CUTE_HOST_DEVICE constexpr -constant -operator/(constant, U) { +C<0> +operator*(U, C) { return {}; } -template ::value)> +template ::value && t == 0)> CUTE_HOST_DEVICE constexpr -constant -operator%(U, constant) { +C<0> +operator/(C, U) { return {}; } -template ::value)> +template ::value && (t == 1 || t == -1))> CUTE_HOST_DEVICE constexpr -constant -operator%(U, constant) { +C<0> +operator%(U, C) { return {}; } -template ::value)> +template ::value && t == 0)> CUTE_HOST_DEVICE constexpr -constant -operator%(constant, U) { +C<0> +operator%(C, U) { return {}; } -template ::value)> +template ::value && t == 0)> CUTE_HOST_DEVICE constexpr -constant -operator&(constant, U) { +C<0> +operator&(C, U) { return {}; } -template ::value)> +template ::value && t == 0)> CUTE_HOST_DEVICE constexpr -constant -operator&(U, constant) { +C<0> +operator&(U, C) { return {}; } -template ::value && !bool(t))> +template ::value && !bool(t))> CUTE_HOST_DEVICE constexpr -constant -operator&&(constant, U) { +C +operator&&(C, U) { return {}; } -template ::value && !bool(t))> +template ::value && !bool(t))> CUTE_HOST_DEVICE constexpr -constant -operator&&(U, constant) { +C +operator&&(U, C) { return {}; } -template ::value && bool(t))> +template ::value && bool(t))> CUTE_HOST_DEVICE constexpr -constant -operator||(constant, U) { +C +operator||(C, U) { return {}; } -template ::value && bool(t))> +template ::value && bool(t))> CUTE_HOST_DEVICE constexpr -constant -operator||(U, constant) { +C +operator||(U, C) { return {}; } @@ -304,34 +299,27 @@ operator||(U, constant) { // #define CUTE_NAMED_UNARY_FN(OP) \ - template \ + template \ CUTE_HOST_DEVICE constexpr \ - constant \ - OP (constant) { \ + C OP (C) { \ return {}; \ } - #define CUTE_NAMED_BINARY_FN(OP) \ - template \ + template \ CUTE_HOST_DEVICE constexpr \ - constant \ - OP (constant, constant) { \ + C OP (C, C) { \ return {}; \ } \ - \ - template ::value)> \ + template ::value)> \ CUTE_HOST_DEVICE constexpr \ - auto \ - OP (constant, U u) { \ + auto OP (C, U u) { \ return OP(t,u); \ } \ - \ - template ::value)> \ + template ::value)> \ CUTE_HOST_DEVICE constexpr \ - auto \ - OP (T t, constant) { \ + auto OP (T t, C) { \ return OP(t,u); \ } @@ -353,32 +341,30 @@ CUTE_NAMED_BINARY_FN(lcm); // Other functions // -template +template CUTE_HOST_DEVICE constexpr -constant -safe_div(constant, constant) { +C +safe_div(C, C) { static_assert(t % u == 0, "Static safe_div requires t % u == 0"); return {}; } -template ::value)> +template ::value)> CUTE_HOST_DEVICE constexpr auto -safe_div(constant, U u) { +safe_div(C, U u) { return t / u; } -template ::value)> +template ::value)> CUTE_HOST_DEVICE constexpr auto -safe_div(T t, constant) { +safe_div(T t, C) { return t / u; } -// cute::true_type prefers standard conversion to std::true_type -// over user-defined conversion to bool template CUTE_HOST_DEVICE constexpr decltype(auto) @@ -386,8 +372,6 @@ conditional_return(true_type, TrueType&& t, FalseType&&) { return static_cast(t); } -// cute::false_type prefers standard conversion to std::false_type -// over user-defined conversion to bool template CUTE_HOST_DEVICE constexpr decltype(auto) @@ -419,15 +403,15 @@ conditional_return(TrueType const& t, FalseType const& f) { // Display utilities // -template -CUTE_HOST_DEVICE void print(integral_constant const&) { - printf("_%d", N); +template +CUTE_HOST_DEVICE void print(C const&) { + printf("_%d", int(t)); } #if !defined(__CUDACC_RTC__) -template -CUTE_HOST std::ostream& operator<<(std::ostream& os, integral_constant const&) { - return os << "_" << N; +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, C const&) { + return os << "_" << t; } #endif diff --git a/include/cute/numeric/math.hpp b/include/cute/numeric/math.hpp index a90716a64a..ec46fd79a4 100644 --- a/include/cute/numeric/math.hpp +++ b/include/cute/numeric/math.hpp @@ -79,8 +79,8 @@ abs(T const& t) { // Greatest common divisor of two integers template ::value && - CUTE_STL_NAMESPACE::is_integral::value)> + __CUTE_REQUIRES(is_std_integral::value && + is_std_integral::value)> CUTE_HOST_DEVICE constexpr auto gcd(T t, U u) { @@ -94,8 +94,8 @@ gcd(T t, U u) { // Least common multiple of two integers template ::value && - CUTE_STL_NAMESPACE::is_integral::value)> + __CUTE_REQUIRES(is_std_integral::value && + is_std_integral::value)> CUTE_HOST_DEVICE constexpr auto lcm(T const& t, U const& u) { @@ -301,8 +301,8 @@ signum(T const& x) { // @pre t % u == 0 // @result t / u template ::value && - CUTE_STL_NAMESPACE::is_integral::value)> + __CUTE_REQUIRES(is_std_integral::value && + is_std_integral::value)> CUTE_HOST_DEVICE constexpr auto safe_div(T const& t, U const& u) { diff --git a/include/cute/pointer.hpp b/include/cute/pointer.hpp index da32784fb3..479ad699b5 100644 --- a/include/cute/pointer.hpp +++ b/include/cute/pointer.hpp @@ -51,6 +51,13 @@ template struct has_dereference())>> : true_type { }; +template +CUTE_HOST_DEVICE constexpr +T* +raw_pointer_cast(T* ptr) { + return ptr; +} + // // Pointer categories // @@ -92,13 +99,20 @@ struct device_ptr CUTE_HOST_DEVICE constexpr friend ptrdiff_t operator-(device_ptr const& a, - device_ptr const& b) { + device_ptr const& b) { return a.ptr_ - b.ptr_; } T* ptr_; }; +template +CUTE_HOST_DEVICE constexpr +T* +raw_pointer_cast(device_ptr ptr) { + return ptr.get(); +} + // // gmem_ptr // @@ -122,6 +136,24 @@ make_gmem_ptr(void* ptr) { return {reinterpret_cast(ptr)}; } +template +CUTE_HOST_DEVICE constexpr +gmem_ptr +make_gmem_ptr(void const* ptr) { + return {reinterpret_cast(ptr)}; +} + +// nullptr_t overloads are needed because otherwise, +// make_gmem_ptr(nullptr) will be ambiguous, +// as std::nullptr_t can be converted to any pointer +// or pointer to member type. +template +CUTE_HOST_DEVICE constexpr +gmem_ptr +make_gmem_ptr(decltype(nullptr)) { // nullptr_t + return {static_cast(nullptr)}; +} + template struct is_gmem> : true_type {}; @@ -148,6 +180,13 @@ make_smem_ptr(void* ptr) { return {reinterpret_cast(ptr)}; } +template +CUTE_HOST_DEVICE constexpr +smem_ptr +make_smem_ptr(void const* ptr) { + return {reinterpret_cast(ptr)}; +} + template struct is_smem> : true_type {}; @@ -174,6 +213,13 @@ make_rmem_ptr(void* ptr) { return {reinterpret_cast(ptr)}; } +template +CUTE_HOST_DEVICE constexpr +rmem_ptr +make_rmem_ptr(void const* ptr) { + return {reinterpret_cast(ptr)}; +} + template struct is_rmem> : true_type {}; diff --git a/include/cute/stride.hpp b/include/cute/stride.hpp index 515bb7b3b5..06d4b97755 100644 --- a/include/cute/stride.hpp +++ b/include/cute/stride.hpp @@ -273,7 +273,7 @@ using GenRowMajor = LayoutRight; // Alias namespace detail { -// GGC8.5 WAR -- Use of lambdas in unevaluated contexts. Instead use function objects. +// For GCC8.5 -- Use of lambdas in unevaluated contexts. Instead use function objects. template struct CompactLambda; @@ -300,7 +300,7 @@ compact(Shape const& shape, CUTE_GCC_UNREACHABLE; } -// GCC8.5 WAR -- Specialization LayoutLeft +// For GCC8.5 -- Specialization LayoutLeft template <> struct CompactLambda { @@ -315,7 +315,7 @@ struct CompactLambda using seq = tuple_seq; // Seq }; -// GCC8.5 WAR -- Specialization LayoutRight +// For GCC8.5 -- Specialization LayoutRight template <> struct CompactLambda { @@ -419,8 +419,15 @@ CUTE_HOST_DEVICE constexpr auto compact_order(Shape const& shape, Order const& order) { - static_assert(is_congruent::value, "Need congruence of shape and order."); - return detail::compact_order(shape, order, flatten_to_tuple(shape), flatten_to_tuple(order)); + if constexpr(is_congruent::value) { + return detail::compact_order(shape, order, flatten_to_tuple(shape), flatten_to_tuple(order)); + } + else + { + // Here we only want to apply order to top-level subshapes and default (col-major) order on other levels + static_assert(rank(Shape{}) == rank(Order{}), "Need equal rank of shape and order"); + return detail::compact_order(shape, order, shape, order); + } } template diff --git a/include/cute/swizzle.hpp b/include/cute/swizzle.hpp index ec5ee81816..c8d910a03b 100644 --- a/include/cute/swizzle.hpp +++ b/include/cute/swizzle.hpp @@ -72,8 +72,7 @@ struct Swizzle static constexpr uint32_t swizzle_code = uint32_t(yyy_msk{} | zzz_msk{}); - template ::value)> + template CUTE_HOST_DEVICE constexpr static auto apply(Offset const& offset) @@ -81,8 +80,7 @@ struct Swizzle return offset ^ shiftr(offset & yyy_msk{}, msk_sft{}); // ZZZ ^= YYY } - template ::value)> + template CUTE_HOST_DEVICE constexpr auto operator()(Offset const& offset) const @@ -91,11 +89,6 @@ struct Swizzle } }; -// Translation for legacy SwizzleXor -// TODO: Deprecate -template -using SwizzleXor = Swizzle; - // // make_swizzle<0b1000, 0b0100>() -> Swizzle<1,2,1> // make_swizzle<0b11000000, 0b00000110>() -> Swizzle<2,1,5> @@ -131,6 +124,44 @@ composition(Swizzle, Swizzle) //return ComposedFn, Swizzle>{}; } +// +// Inverse +// + +template +CUTE_HOST_DEVICE constexpr +Swizzle +right_inverse(Swizzle const& sw) +{ + return sw; +} + +template +CUTE_HOST_DEVICE constexpr +Swizzle +left_inverse(Swizzle const& sw) +{ + return sw; +} + +// Kludge -- Probably want an OffsetFn here instead +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +right_inverse(T const& t) +{ + return -t; +} + +// Kludge -- Probably want an OffsetFn here instead +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +left_inverse(T const& t) +{ + return -t; +} + // // Upcast and Downcast // @@ -205,6 +236,8 @@ struct MixedBits // assert((dynamic_int_ & ~F) == 0); DynamicType dynamic_int_; + + CUTE_HOST_DEVICE constexpr operator uint32_t() const noexcept { return StaticInt | dynamic_int_; } }; template @@ -225,28 +258,6 @@ make_mixed_bits(constant const&, DynamicType const& d, constant const& CUTE_GCC_UNREACHABLE; } -// -// Explicit conversion for now -- consider casting on plus or minus -// - -template -CUTE_HOST_DEVICE constexpr -auto -to_integral(MixedBits const& m) -{ - //return S | (m.dynamic_int_ & F); - return S | m.dynamic_int_; -} - -// Any cute::is_integral -template ::value)> -CUTE_HOST_DEVICE constexpr -auto -to_integral(I const& i) -{ - return i; -} - // // Operators // @@ -383,6 +394,50 @@ operator^(constant const& s, MixedBits const& m) return m ^ s; } +template +CUTE_HOST_DEVICE constexpr +auto +operator<<(MixedBits const& m, constant const&) +{ + return make_mixed_bits(constant{}, + m.dynamic_int_ << S1, + constant{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator>>(MixedBits const& m, constant const&) +{ + return make_mixed_bits(constant> S1)>{}, + m.dynamic_int_ >> S1, + constant> S1)>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +shiftl(MixedBits const& m, constant const& s) +{ + if constexpr (S1 >= 0) { + return m << s; + } else { + return m >> -s; + } +} + +template +CUTE_HOST_DEVICE constexpr +auto +shiftr(MixedBits const& m, constant const& s) +{ + if constexpr (S1 >= 0) { + return m >> s; + } else { + return m << -s; + } +} + // // upcast and downcast // @@ -473,14 +528,14 @@ to_mixed_bits(Layout const& layout, Coord const& coord) template CUTE_HOST_DEVICE void print(MixedBits const& m) { - printf("M_%u|(%u&%u)=%u", S, uint32_t(m.dynamic_int_), F, to_integral(m)); + printf("M_%u|(%u&%u)=%u", S, uint32_t(m.dynamic_int_), F, uint32_t(m)); } #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& operator<<(std::ostream& os, MixedBits const& m) { - return os << "M_" << S << "|(" << uint32_t(m.dynamic_int_) << "&" << F << ")=" << to_integral(m); + return os << "M_" << S << "|(" << uint32_t(m.dynamic_int_) << "&" << F << ")=" << uint32_t(m); } template diff --git a/include/cute/swizzle_layout.hpp b/include/cute/swizzle_layout.hpp index 8303731eae..a5919716e2 100644 --- a/include/cute/swizzle_layout.hpp +++ b/include/cute/swizzle_layout.hpp @@ -33,219 +33,29 @@ #include #include +#include #include -/* This implements a ComposedLayout of the form - * InvolutionFn o OffsetPlus o Layout - * where the InvolutionFn need not be linear (hence the need for the Offset). +/* Specialized functionality for a ComposedLayout of the form + * InvolutionFn o Offset o LayoutB + * where the InvolutionFn is a Swizzle and is not linear (hence the need for the Offset). * - * This ComposedLayout provides similar coordinate-to-index mapping and layout manipulations, - * but is not considered a "normal" layout. - * For example, this layout provides size() functions, but does not provide stride() functions. + * Because these are specializations for core functions of ComposedLayout, these Swizzle Layouts + * provide similar functionality to Layout including tiling, partitioning, + * coordinate-to-index mapping and layout manipulations, but are not considered "normal" layouts. + * For example, these provide shape() and size() functions, but do not provide stride() functions. * - * Furthermore, for known InvolutionFns, this layout attempts to decay itself - * to a normal-layout with dynamic or static strides. - * This is possible by determining the subdomain of the Involution function - * that is identity and testing if the right Layout's codomain is contained - * within it. + * Furthermore, each of these specializations uses Swizzle<>-specific knowledge in its implementation and + * attempts to decay itself to a normal-layout with dynamic or static strides when certain slicing conditions + * are met. This is possible by determining the subdomain of the Swizzle<> function that is identity and + * testing if LayoutB's codomain is contained within it. In general, MizedBits is used as the Offset to track + * statically-vs-dynamically known bits in the Offset to improve the decay to static or dynamic normal layouts. */ namespace cute { -// A Layout of non-trivially composable functions: F o I o L -template -struct ComposedLayout - : private cute::tuple // EBO for static layouts -{ - CUTE_HOST_DEVICE constexpr - ComposedLayout(InvolutionFn const& fn = {}, - IntermediateOffset const& offset = {}, - Layout const& layout = {}) - : cute::tuple(fn, offset, layout) - {} - - // - // Accessors - // - - static constexpr int rank = Layout::rank; - - CUTE_HOST_DEVICE constexpr - decltype(auto) - swizzle_fn() const { - return get<0>(static_cast const&>(*this)); - } - - CUTE_HOST_DEVICE constexpr - decltype(auto) - offset_fn() const { - return get<1>(static_cast const&>(*this)); - } - - CUTE_HOST_DEVICE constexpr - decltype(auto) - layout_fn() const { - return get<2>(static_cast const&>(*this)); - } - - CUTE_HOST_DEVICE constexpr - decltype(auto) - layout() const { - return *this; - } - - CUTE_HOST_DEVICE constexpr - decltype(auto) - shape() const { - return layout_fn().shape(); - } - - // Doesn't really make sense to ask for the strides of this "layout" - CUTE_HOST_DEVICE constexpr - decltype(auto) - stride() const = delete; - - // - // Mappings - // - - // Map a logical coordinate to a linear index (Coord has no Underscore slice operators) - // OR - // Slice the layout and return the sublayout (Coord has an Underscore slice op) - template - CUTE_HOST_DEVICE constexpr - auto - operator()(Coord const& coord) const { - if constexpr (has_underscore::value) { - return slice(coord, *this); - } else { - return swizzle_fn()(to_integral(offset_fn()) + layout_fn()(coord)); // (F o L)(c) - } - - CUTE_GCC_UNREACHABLE; - } - - // Map a 1D linear coordinate to a flat ND logical coordinate - template ::value)> - CUTE_HOST_DEVICE constexpr - auto - operator[](Int const& linear_idx) const { - return get_flat_coord(linear_idx); - } - - // Convenience function for multi-dimensional coordinates - template - CUTE_HOST_DEVICE constexpr - auto - operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { - return operator()(make_coord(c0,c1,cs...)); - } - - // - // Compose - // - - template - CUTE_HOST_DEVICE constexpr - auto - compose(OtherLayout const& other) const { - return composition(*this, other); - } - - template - CUTE_HOST_DEVICE constexpr - auto - compose(Layouts const&... layouts) const { - return composition(*this, make_tile(layouts...)); - } - - template - CUTE_HOST_DEVICE constexpr - auto - with_shape(OtherShape const& shape) const { - return composition(*this, make_layout(shape)); - } - - template - CUTE_HOST_DEVICE constexpr - auto - with_shape(Shapes const&... shapes) const { - return composition(*this, make_layout(make_shape(shapes...))); - } - - // - // Tile - // - - template - CUTE_HOST_DEVICE constexpr - auto - tile(OtherLayout const& other) const { - return tiled_divide(*this, other); - } - - template - CUTE_HOST_DEVICE constexpr - auto - tile(Layouts const&... layouts) const { - return tiled_divide(*this, make_tile(layouts...)); - } - - // - // Utility - // - - // - // Index to Coordinate - // - - // NOTE Only valid for compact layouts - - // Return the (hierarchical) ND logical coordinate corresponding to the linear index - // @post this->crd2idx(@a result) == idx - // @post congruent(@a result, shape()) - template ::value)> - CUTE_HOST_DEVICE constexpr - auto - get_hier_coord(IInt const& idx) const { - return layout_fn().get_hier_coord(swizzle_fn()(idx) - to_integral(offset_fn())); // (L^-1 o F)(k) - } - - // Return the (flat) ND logical coordinate corresponding to the linear index - // @post this->crd2idx(@a result) == idx - // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 - template ::value)> - CUTE_HOST_DEVICE constexpr - auto - get_flat_coord(IInt const& idx) const { - return layout_fn().get_flat_coord(swizzle_fn()(idx) - to_integral(offset_fn())); // (L^-1 o F)(k) - } - - // Return the generalized column-major 1D logical coordinate corresponding to the linear index - // @post this->crd2idx(@a result) == idx - // @post is_integral::value - template ::value)> - CUTE_HOST_DEVICE constexpr - auto - get_1d_coord(IInt const& idx) const { - return layout_fn().get_1d_coord(swizzle_fn()(idx) - to_integral(offset_fn())); // (L^-1 o F)(k) - } -}; - -template -struct is_layout> : true_type {}; - -template -struct is_composed_layout : false_type {}; -template -struct is_composed_layout> : true_type {}; - // // Constructors // @@ -258,22 +68,6 @@ make_layout(Swizzle const& sxor) return composition(sxor, Layout,Int<1>>{}); } -template -CUTE_HOST_DEVICE constexpr -auto -make_layout(ComposedLayout const& a, Layout const& b) -{ - return composition(a.swizzle_fn(), a.offset_fn(), make_layout(a.layout_fn(), b)); -} - -template -CUTE_HOST_DEVICE constexpr -auto -make_layout(Layout const& a, ComposedLayout const& b) -{ - return composition(b.swizzle_fn(), b.offset_fn(), make_layout(a, b.layout_fn())); -} - namespace detail { template @@ -318,124 +112,57 @@ transfer_swizzle(Layout const& old_layout, } // end namespace detail -template +template CUTE_HOST_DEVICE constexpr auto -make_fragment_like(ComposedLayout,Offset,Layout> const& layout) +make_fragment_like(ComposedLayout,Offset,Layout> const& layout) { - return detail::transfer_swizzle(layout.layout_fn(), make_fragment_like(layout.layout_fn())); + return detail::transfer_swizzle(layout.layout_b(), make_fragment_like(layout.layout_b())); } // // Utilities // -// Return the layout of a mode -template -CUTE_HOST_DEVICE constexpr -decltype(auto) -layout(ComposedLayout const& clayout) -{ - return composition(clayout.swizzle_fn(), clayout.offset_fn(), layout(clayout.layout_fn())); -} - -// Return the shape of a mode -template -CUTE_HOST_DEVICE constexpr -decltype(auto) -shape(ComposedLayout const& layout) -{ - return shape(layout.layout_fn()); -} - -// Doesn't make sense to directly ask for the strides of this "layout" -template -CUTE_HOST_DEVICE constexpr -decltype(auto) -stride(ComposedLayout const& layout) = delete; - -// Return the number of elements in a mode -template -CUTE_HOST_DEVICE constexpr -decltype(auto) -size(ComposedLayout const& layout) -{ - return size(layout.layout_fn()); -} - -// Return the number of modes -template -CUTE_HOST_DEVICE constexpr -auto -rank(ComposedLayout const& layout) -{ - return rank(layout.layout_fn()); -} - -// Return the depth of the layout -template -CUTE_HOST_DEVICE constexpr -auto -depth(ComposedLayout const& layout) -{ - return depth(layout.layout_fn()); -} - -// Return the codomain size of a mode -template -CUTE_HOST_DEVICE constexpr -auto -cosize(ComposedLayout const& layout) -{ - return cosize(layout.layout_fn()); -} - -// -// Operations to manipulate Layouts like a tuple of pairs -// +namespace detail { -template -CUTE_HOST_DEVICE constexpr +// Get just the Swizzle part of a composed layout. +template auto -get(ComposedLayout const& a) +get_swizzle_portion(ComposedLayout,Offset,LayoutB>) { - return composition(a.swizzle_fn(), a.offset_fn(), get(a.layout_fn())); + return Swizzle{}; } -template -CUTE_HOST_DEVICE constexpr +// A non-swizzled layout's "Swizzle part" is the identity swizzle. +template auto -take(ComposedLayout const& a) +get_swizzle_portion(Layout) { - return composition(a.swizzle_fn(), a.offset_fn(), take(a.layout_fn())); + return Swizzle<0,4,3>{}; } -template -CUTE_HOST_DEVICE constexpr +// Get the "non-swizzle" part of a composed layout, +// which is the underlying (non-composed) Layout. +template auto -flatten(ComposedLayout const& a) +get_nonswizzle_portion(ComposedLayout,Offset,LayoutB> const& slayout) { - return composition(a.swizzle_fn(), a.offset_fn(), flatten(a.layout_fn())); + return slayout.layout_b(); } -template -CUTE_HOST_DEVICE constexpr +// The non-swizzle part of a non-swizzled layout is just the Layout. +template auto -append(ComposedLayout const& a, X const& x) +get_nonswizzle_portion(Layout const& slayout) { - return composition(a.swizzle_fn(), a.offset_fn(), append(a.layout_fn(), x)); + return slayout; } -template -CUTE_HOST_DEVICE constexpr -auto -group(ComposedLayout const& a) -{ - return composition(a.swizzle_fn(), a.offset_fn(), group(a.layout_fn())); -} +} // namespace detail // -// Slice a ComposedLayout +// Slice a Swizzled ComposedLayout // namespace detail { @@ -491,7 +218,7 @@ slice_and_offset(Coord const& coord, ComposedLayout,Offset,Layout { if constexpr (all_underscore::value) { // Skip the expensive/complicated attempt to decay to a normal layout and just reshape - return cute::make_tuple(composition(layout.swizzle_fn(), layout.offset_fn(), slice(coord, layout.layout_fn())), Int<0>{}); + return cute::make_tuple(composition(layout.layout_a(), layout.offset(), slice(coord, layout.layout_b())), Int<0>{}); } else { // Projections of the swizzle layout for composition @@ -503,7 +230,7 @@ slice_and_offset(Coord const& coord, ComposedLayout,Offset,Layout make_stride( Int<0>{}, stride<1>(sw), Int<0>{}, stride<3>(sw), Int<0>{})); // The portion of the layout that is not yet consumed - auto sliced_layout = slice(coord, layout.layout_fn()); + auto sliced_layout = slice(coord, layout.layout_b()); // If the sliced_layout hits two bits that are swizzled together, then don't attempt to decay @@ -513,20 +240,20 @@ slice_and_offset(Coord const& coord, ComposedLayout,Offset,Layout // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) auto swizzle_active_bits = sliced_layout_only_zy(size(sliced_layout_only_zy)-Int<1>{}); // Determine if any active bits collide under the swizzle - auto hit_ZandY = !(swizzle_active_bits & ~layout.swizzle_fn()(swizzle_active_bits)); + auto hit_ZandY = !(swizzle_active_bits & ~layout.layout_a()(swizzle_active_bits)); // The portion of the layout that we are consuming now - auto diced_layout = dice(coord, layout.layout_fn()); + auto diced_layout = dice(coord, layout.layout_b()); auto diced_coord = dice(coord, coord); auto diced_layout_anti_zy = composition(swizzle_anti_zy, diced_layout); auto diced_layout_only_zy = composition(swizzle_only_zy, diced_layout); // New swizzle and offset - auto swizzle = layout.swizzle_fn(); - // offset_only_zy interacts with swizzle and gets accumulated with layout.offset_fn() + auto swizzle = layout.layout_a(); + // offset_only_zy interacts with swizzle and gets accumulated with layout.offset() // being careful about the static/dynamic contributions from diced_layout and diced_coord - auto offset_only_zy = layout.offset_fn() ^ to_mixed_bits(diced_layout_only_zy, diced_coord); + auto offset_only_zy = layout.offset() ^ to_mixed_bits(diced_layout_only_zy, diced_coord); // offset_anti_zy always gets passed through, no interaction with swizzle auto offset_anti_zy = diced_layout_anti_zy(diced_coord); @@ -553,50 +280,17 @@ slice_and_offset(Coord const& coord, ComposedLayout,Offset,Layout // Decay to a normal layout with offset return cute::make_tuple(composition(swizzle_layout, sliced_layout), - swizzle(to_integral(offset_only_zy)) + offset_anti_zy); + swizzle(offset_only_zy) + offset_anti_zy); } } CUTE_GCC_UNREACHABLE; } -template -CUTE_HOST_DEVICE constexpr -auto -slice(Coord const& coord, ComposedLayout const& layout) -{ - return get<0>(slice_and_offset(coord, layout)); -} - // // composition // -template -CUTE_HOST_DEVICE constexpr -auto -composition(Swizzle const& sxor, - Offset const& offset, - Layout const& layout) -{ - return ComposedLayout>{sxor, offset, layout}; -} - -template -CUTE_HOST_DEVICE constexpr -auto -composition(Swizzle const& sxor, - Offset const& offset, - ComposedLayout const& layout) -{ - // Assume disjoint swizzles and offsets for commutivity - return composition(composition(sxor,layout.swizzle_fn()), offset ^ layout.offset_fn(), layout.layout_fn()); -} - // Ignore identity case template @@ -619,16 +313,6 @@ composition(Swizzle const& sxor, return composition(sxor, Int<0>{}, layout); } -template -CUTE_HOST_DEVICE constexpr -auto -composition(ComposedLayout const& a, - LayoutOrTile const& b) -{ - return composition(a.swizzle_fn(), a.offset_fn(), composition(a.layout_fn(), b)); -} - template CUTE_HOST_DEVICE constexpr @@ -645,106 +329,69 @@ composition(Layout const& a, return composition(make_swizzle(), a); } -template -CUTE_HOST_DEVICE constexpr -auto -composition(Layout const& a, - ComposedLayout const& b) -{ - CUTE_STATIC_ASSERT_V(b.offset_fn() == Int<0>{}, "Require Swizzle offset == 0."); - - return composition(composition(a, b.swizzle_fn()), b.layout_fn()); -} - -template -CUTE_HOST_DEVICE constexpr -auto -composition(ComposedLayout const& a, - ComposedLayout const& b) -{ - auto asb = composition(a.layout_fn(), b); - - return composition(composition(a.swizzle_fn(),asb.swizzle_fn()), asb.offset_fn(), asb.layout_fn()); -} - -// -// complement -// - -template -CUTE_HOST_DEVICE constexpr -auto -complement(ComposedLayout const& layout, CoSizeHi const& cosize_hi) -{ - // Assume there is no swizzle component in the complement - return complement(layout.layout_fn(), cosize_hi); -} - -template -CUTE_HOST_DEVICE constexpr -auto -complement(ComposedLayout const& layout) -{ - return complement(layout, cosize(layout)); -} - // // inverse // -template +// Specialization to attempt to pass-through the Swizzle back to the left -- Needed? +template CUTE_HOST_DEVICE constexpr auto -right_inverse(ComposedLayout const& layout) +right_inverse(ComposedLayout,Offset,Layout> const& layout) { - CUTE_STATIC_ASSERT_V(layout.offset_fn() == Int<0>{}, "Requires 0-offset."); - return composition(right_inverse(layout.layout_fn()), layout.swizzle_fn()); + if constexpr (is_constant<0, Offset>::value) { + return composition(right_inverse(layout.layout_b()), layout.layout_a()); + } else { + return composition(right_inverse(layout.layout_b()), right_inverse(layout.offset()), right_inverse(layout.layout_a())); + } } -template +// Specialization to attempt to pass-through the Swizzle back to the left -- Needed? +template CUTE_HOST_DEVICE constexpr auto -left_inverse(ComposedLayout const& layout) +left_inverse(ComposedLayout,Offset,Layout> const& layout) { - CUTE_STATIC_ASSERT_V(layout.offset_fn() == Int<0>{}, "Requires 0-offset."); - return composition(left_inverse(layout.layout_fn()), layout.swizzle_fn()); + if constexpr (is_constant<0, Offset>::value) { + return composition(left_inverse(layout.layout_b()), layout.layout_a()); + } else { + return composition(left_inverse(layout.layout_b()), left_inverse(layout.offset()), left_inverse(layout.layout_a())); + } } // // Other operations // -template +template CUTE_HOST_DEVICE constexpr auto -max_common_vector(ComposedLayout,Offset,SLayout> const& a, +max_common_vector(ComposedLayout,Offset,LayoutB> const& a, Layout const& b) { // This assumes that Offset is in the YZ domain of the Swizzle... - return cute::min(Int<(1 << M)>{}, max_common_vector(a.layout_fn(), b)); + return cute::min(Int<(1 << M)>{}, max_common_vector(a.layout_b(), b)); } -template +template CUTE_HOST_DEVICE constexpr auto max_common_vector(Layout const& a, - ComposedLayout,Offset,SLayout> const& b) + ComposedLayout,Offset,LayoutB> const& b) { return max_common_vector(b, a); } -template +template CUTE_HOST_DEVICE constexpr auto -max_common_vector(ComposedLayout,Offset0,SLayout0> const& a, - ComposedLayout,Offset1,SLayout1> const& b) +max_common_vector(ComposedLayout,Offset0,LayoutB0> const& a, + ComposedLayout,Offset1,LayoutB1> const& b) { auto result = coalesce(composition(a, right_inverse(b))); - if constexpr (is_constant<1, decltype(stride<0>(result.layout_fn()))>::value) { + if constexpr (is_constant<1, decltype(stride<0>(result.layout_b()))>::value) { return shape<0>(result); } else { return Int<1>{}; @@ -753,132 +400,6 @@ max_common_vector(ComposedLayout,Offset0,SLayout0> const& a, CUTE_GCC_UNREACHABLE; } -template -CUTE_HOST_DEVICE constexpr -auto -zip(ComposedLayout const& a) -{ - return composition(a.swizzle_fn(), a.offset_fn(), zip(a.layout_fn())); -} - -// Partitions - -template -CUTE_HOST_DEVICE constexpr -auto -logical_divide(ComposedLayout const& a, - Tile const& b) -{ - return composition(a.swizzle_fn(), a.offset_fn(), logical_divide(a.layout_fn(), b)); -} - -template -CUTE_HOST_DEVICE constexpr -auto -tile_unzip(ComposedLayout const& a, - Tile const& b) -{ - return composition(a.swizzle_fn(), a.offset_fn(), tile_unzip(a.layout_fn(), b)); -} - -template -CUTE_HOST_DEVICE constexpr -auto -tiled_divide(ComposedLayout const& a, - Tile const& b) -{ - return composition(a.swizzle_fn(), a.offset_fn(), tiled_divide(a.layout_fn(), b)); -} - -template -CUTE_HOST_DEVICE constexpr -auto -zipped_divide(ComposedLayout const& a, - Tile const& b) -{ - return composition(a.swizzle_fn(), a.offset_fn(), zipped_divide(a.layout_fn(), b)); -} - -template -CUTE_HOST_DEVICE constexpr -auto -logical_product(ComposedLayout const& a, - Tile const& b) -{ - return composition(a.swizzle_fn(), a.offset_fn(), logical_product(a.layout_fn(), b)); -} - -template -CUTE_HOST_DEVICE constexpr -auto -tiled_product(ComposedLayout const& a, - Tile const& b) -{ - return composition(a.swizzle_fn(), a.offset_fn(), tiled_product(a.layout_fn(), b)); -} - -template -CUTE_HOST_DEVICE constexpr -auto -blocked_product(ComposedLayout const& a, - Tile const& b) -{ - return composition(a.swizzle_fn(), a.offset_fn(), blocked_product(a.layout_fn(), b)); -} - -template -CUTE_HOST_DEVICE constexpr -auto -raked_product(ComposedLayout const& a, - Tile const& b) -{ - return composition(a.swizzle_fn(), a.offset_fn(), raked_product(a.layout_fn(), b)); -} - -template -CUTE_HOST_DEVICE constexpr -auto -tile_to_shape(ComposedLayout const& layout, - Shape const& trg_shape, - ModeOrder const& ord_shape = {}) -{ - return composition(layout.swizzle_fn(), layout.offset_fn(), tile_to_shape(layout.layout_fn(), trg_shape, ord_shape)); -} - -template -CUTE_HOST_DEVICE constexpr -auto -filter(ComposedLayout const& layout, Shape const& trg_profile) -{ - return composition(layout.swizzle_fn(), layout.offset_fn(), filter(layout.layout_fn(), trg_profile)); -} - -template -CUTE_HOST_DEVICE constexpr -auto -coalesce(ComposedLayout const& layout) -{ - return composition(layout.swizzle_fn(), layout.offset_fn(), coalesce(layout.layout_fn())); -} - -template -CUTE_HOST_DEVICE constexpr -auto -coalesce(ComposedLayout const& layout, Shape const& trg_profile) -{ - return composition(layout.swizzle_fn(), layout.offset_fn(), coalesce(layout.layout_fn(), trg_profile)); -} - /////////////////////////////////////////////////////////////////////////////// // ComposedLayout as second argument is often more difficult... @@ -889,10 +410,10 @@ auto logical_product(Layout const& block, ComposedLayout,Offset,LayoutT> const& tile) { - CUTE_STATIC_ASSERT_V(tile.offset_fn() == Int<0>{}, "Require Swizzle offset == 0."); + CUTE_STATIC_ASSERT_V(tile.offset() == Int<0>{}, "Require Swizzle offset == 0."); // The new layout -- if swizzle wasn't an issue, this is the result // our goal is to determine a new swizzle for these strides - auto new_layout = logical_product(block, tile.layout_fn()); + auto new_layout = logical_product(block, tile.layout_b()); // This is accomplished by identifying // S o L :=: S? o L* @@ -906,7 +427,7 @@ logical_product(Layout const& block, make_stride( Int<0>{}, Int<(1 << M)>{}, Int<0>{}, Int<(1 << (M+abs(S)))>{}, Int<0>{})); // Compose with the tile to get the swizzle projection, P o L [The Z and Y contributing portions of L] - auto layout_only_zy = composition(swizzle_only_zy, tile.layout_fn()); + auto layout_only_zy = composition(swizzle_only_zy, tile.layout_b()); // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) auto swizzle_active_bits = layout_only_zy(size(layout_only_zy)-Int<1>{}); // Get the Z bit and the Y bits @@ -914,99 +435,12 @@ logical_product(Layout const& block, auto active_Y = swizzle_active_bits & typename Swizzle::yyy_msk{}; // Pass the identifiers through the old layout and new layout to make a new swizzle identifier, L*(L[(P o L)(c*)]) - auto new_active_Z = new_layout(Int<0>{}, tile.layout_fn()[active_Z]); - auto new_active_Y = new_layout(Int<0>{}, tile.layout_fn()[active_Y]); + auto new_active_Z = new_layout(Int<0>{}, tile.layout_b()[active_Z]); + auto new_active_Y = new_layout(Int<0>{}, tile.layout_b()[active_Y]); // Use this new swizzle identifier to construxt the new swizzle for new_layout // (this also makes sure it's a "valid" swizzle that Swizzle can represent) return composition(make_swizzle(), new_layout); } -template -CUTE_HOST_DEVICE constexpr -auto -tiled_product(Layout const& block, - ComposedLayout const& tile) -{ - /// Avoid swizzle slice - auto result = logical_product(block, tile); - return composition(result.swizzle_fn(), result.offset_fn(), result.layout_fn()(_, repeat>(_))); -} - -template -CUTE_HOST_DEVICE constexpr -auto -blocked_product(Layout const& block, - ComposedLayout const& layout) -{ - constexpr int R = cute::max(rank_v, rank_v); - auto padded_block = append(block, Layout<_1,_0>{}); - auto padded_layout = append(layout, Layout<_1,_0>{}); - - auto result = logical_product(padded_block, padded_layout); - - return composition(result.swizzle_fn(), - result.offset_fn(), - coalesce(zip(get<0>(result.layout_fn()), get<1>(result.layout_fn())), repeat(Int<1>{}))); -} - -// -// Upcast and Downcast -// - -template -CUTE_HOST_DEVICE constexpr -auto -upcast(ComposedLayout const& layout) -{ - return composition(upcast(layout.swizzle_fn()), upcast(layout.offset_fn()), upcast(layout.layout_fn())); -} - -template -CUTE_HOST_DEVICE constexpr -auto -downcast(ComposedLayout const& layout) -{ - return composition(downcast(layout.swizzle_fn()), downcast(layout.offset_fn()), downcast(layout.layout_fn())); -} - -template -CUTE_HOST_DEVICE constexpr -auto -recast(ComposedLayout const& layout) -{ - if constexpr (sizeof(NewType) == sizeof(OldType)) { - return layout; - } else if constexpr (sizeof(NewType) > sizeof(OldType)) { - static_assert(sizeof(NewType) % sizeof(OldType) == 0, "NewType must be a multiple of OldType"); - return upcast(layout); - } else if constexpr (sizeof(NewType) < sizeof(OldType)) { - static_assert(sizeof(OldType) % sizeof(NewType) == 0, "NewType must be a divisor of OldType"); - return downcast(layout); - } - - CUTE_GCC_UNREACHABLE; -} - -// -// Display utilities -// - -template -CUTE_HOST_DEVICE void print(ComposedLayout const& layout) -{ - print(layout.swizzle_fn()); print(" o "); print(layout.offset_fn()); print(" o "); print(layout.layout_fn()); -} - -#if !defined(__CUDACC_RTC__) -template -CUTE_HOST std::ostream& operator<<(std::ostream& os, ComposedLayout const& layout) -{ - return os << layout.swizzle_fn() << " o " << layout.offset_fn() << " o " << layout.layout_fn(); -} -#endif - } // end namespace cute diff --git a/include/cute/swizzle_ptr.hpp b/include/cute/swizzle_ptr.hpp index 17ff3bcd1b..50bfbfa2dd 100644 --- a/include/cute/swizzle_ptr.hpp +++ b/include/cute/swizzle_ptr.hpp @@ -34,7 +34,6 @@ #include -#include #include #include @@ -119,11 +118,20 @@ struct is_smem> : true_type {}; template CUTE_HOST_DEVICE constexpr auto -make_smem_ptr(T* ptr, Swizzle const& swizzle) +make_smem_ptr(T* ptr, Swizzle const&) { return smem_ptr_swizzle{ptr}; } +// Specialization for immediate decay +template +CUTE_HOST_DEVICE constexpr +auto +make_smem_ptr(T* ptr, Swizzle<0,M,S> const&) +{ + return make_smem_ptr(ptr); +} + // A model of a nullptr smem_ptr with B == sizeof_bits::value // That represents an unset pointer. This is a placeholder type that is waiting for an smem_ptr template @@ -140,25 +148,8 @@ make_tensor(smem_ptr const& ptr, ComposedLayout,Layout> const& layout) { static_assert(B == sizeof_bits::value, "Expected a B-bit pointer type."); - return make_tensor(make_smem_ptr(ptr.get(), layout.swizzle_fn()), - layout.layout_fn()); -} - -// Specialization for immediate decay -template -CUTE_HOST_DEVICE constexpr -auto -make_tensor(smem_ptr_swizzle>& p, Layout const& layout) -{ - return make_tensor(make_smem_ptr(p.ptr_), layout); -} - -template -CUTE_HOST_DEVICE constexpr -auto -make_tensor(smem_ptr_swizzle> const& p, Layout const& layout) -{ - return make_tensor(make_smem_ptr(p.ptr_), layout); + return make_tensor(make_smem_ptr(ptr.get(), layout.layout_a()), + layout.layout_b()); } // NOTE: To preserve smem_ptr_flag_bits under recast ops @@ -167,7 +158,7 @@ CUTE_HOST_DEVICE constexpr auto upcast(ComposedLayout,Layout> const& layout) { - return composition(layout.swizzle_fn(), smem_ptr_flag_bits{}, upcast(layout.layout_fn())); + return composition(layout.layout_a(), smem_ptr_flag_bits{}, upcast(layout.layout_b())); } template @@ -175,7 +166,7 @@ CUTE_HOST_DEVICE constexpr auto downcast(ComposedLayout,Layout> const& layout) { - return composition(layout.swizzle_fn(), smem_ptr_flag_bits{}, downcast(layout.layout_fn())); + return composition(layout.layout_a(), smem_ptr_flag_bits{}, downcast(layout.layout_b())); } // @@ -199,6 +190,13 @@ recast(smem_ptr_swizzle const& ptr) return smem_ptr_swizzle{recast(ptr.ptr_)}; } +template +CUTE_HOST_DEVICE constexpr +T* +raw_pointer_cast(smem_ptr_swizzle ptr) { + return ptr.get(); +} + // // Conversion with swizzle_layout // @@ -208,7 +206,7 @@ CUTE_HOST_DEVICE auto as_position_independent_swizzle_layout(ComposedLayout,Layout> const& layout) { - return composition(recast,uint_bit_t>(layout.swizzle_fn()), Int<0>{}, layout.layout_fn()); + return composition(recast,uint_bit_t>(layout.layout_a()), Int<0>{}, layout.layout_b()); } template @@ -231,8 +229,8 @@ auto as_position_independent_swizzle_tensor(Tensor>, Layout>& tensor) { { - uint32_t address = cast_smem_ptr_to_uint(tensor.data().get()); - uint32_t mask = ((uint32_t(1) << Swizzle::num_base) - 1) & (Swizzle::swizzle_code); + [[maybe_unused]] uint32_t address = cast_smem_ptr_to_uint(tensor.data().get()); + [[maybe_unused]] uint32_t mask = ((uint32_t(1) << Swizzle::num_base) - 1) & (Swizzle::swizzle_code); assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle } auto new_swizzle = recast,uint_bit_t>>(tensor.data().get_swizzle()); @@ -247,6 +245,24 @@ as_position_independent_swizzle_tensor(Tensor +CUTE_HOST_DEVICE constexpr +auto +as_position_independent_swizzle_tensor(Tensor const& tensor) +{ + return tensor; +} + +template +CUTE_HOST_DEVICE constexpr +auto +as_position_independent_swizzle_tensor(Tensor&& tensor) +{ + return tensor; +} + // // Print // @@ -257,8 +273,8 @@ CUTE_HOST_DEVICE void print_latex(ComposedLayout,Layout> const& layout) { - auto new_swizzle = recast,uint_bit_t>(layout.swizzle_fn()); - print_latex(composition(new_swizzle, Int<0>{}, layout.layout_fn())); + auto new_swizzle = recast,uint_bit_t>(layout.layout_a()); + print_latex(composition(new_swizzle, Int<0>{}, layout.layout_b())); } template diff --git a/include/cute/tensor.hpp b/include/cute/tensor.hpp index deac7b26ce..c4c89de3dd 100644 --- a/include/cute/tensor.hpp +++ b/include/cute/tensor.hpp @@ -411,11 +411,7 @@ CUTE_HOST_DEVICE constexpr auto make_tensor_like(Layout const& layout) { - if constexpr (is_static::value) { - return make_tensor(make_ordered_layout(layout)); - } else { - return make_tensor(make_layout(layout.shape())); - } + return make_tensor(make_layout_like(layout)); } template @@ -465,8 +461,23 @@ make_fragment_like(Tensor const& tensor) return make_fragment_like(tensor.layout()); } +// +// make_counting_tensor +// Make a tensor from a layout by binding it to a counting iter with 0-offset of the same profile as the codomain. +// + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +make_counting_tensor(Layout const& layout) +{ + return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat_like(coshape(layout), Int<0>{}))), + layout); +} + // // make_identity_tensor +// Make a tensor that maps coordinates within a shape to themselves. // template @@ -474,8 +485,7 @@ CUTE_HOST_DEVICE constexpr auto make_identity_tensor(Shape const& shape) { - return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat_like(shape, Int<0>{}))), - make_identity_layout(shape)); + return make_counting_tensor(make_identity_layout(shape)); } // @@ -596,6 +606,44 @@ coalesce(Tensor&& tensor, Profile const& profile) return make_tensor(std::forward(tensor).data(), coalesce(tensor.layout(), profile)); } +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tensor&& tensor) +{ + return make_tensor(cute::forward(tensor).data(), filter_zeros(tensor.layout())); +} + +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +filter(Tensor&& tensor) +{ + return make_tensor(std::forward(tensor).data(), filter(tensor.layout())); +} + +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +filter(Tensor&& tensor, Profile const& profile) +{ + return make_tensor(std::forward(tensor).data(), filter(tensor.layout(), profile)); +} + +// Return a tensor with the same shape as input but offset by a given coordinate +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +domain_offset(Coord const& coord, Tensor&& tensor) +{ + auto [layout, ptr_offset] = domain_offset(coord, tensor.layout()); + return make_tensor(std::forward(tensor).data() + ptr_offset, layout); +} + // Group the modes [B,E) into a single mode // e.g. group<2,4>(make_tensor(Layout>{})) // => make_tensor(Layout,_5,_6>>{}) @@ -642,13 +690,28 @@ recast(Tensor&& tensor, type_list) CUTE_GCC_UNREACHABLE; } -template >::value)> +template +CUTE_HOST_DEVICE constexpr +auto +recast(Tensor const& tensor) +{ + return recast(tensor, type_list{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast(Tensor& tensor) +{ + return recast(tensor, type_list{}); +} + +template CUTE_HOST_DEVICE constexpr auto -recast(Tensor&& tensor) +recast(Tensor&& tensor) { - return recast(std::forward(tensor), type_list{}); + return recast(std::forward>(tensor), type_list{}); } // @@ -840,9 +903,17 @@ local_tile(Tensor && tensor, // Display utilities // +template +CUTE_HOST_DEVICE void print(Tensor const& tensor) +{ + print(tensor.data()); print(" o "); print(tensor.layout()); +} + template CUTE_HOST_DEVICE void print_tensor(Tensor const& tensor) { + print(tensor); print(":\n"); + auto format = get_format(tensor(0)); using type = typename decltype(format)::type; @@ -880,12 +951,7 @@ CUTE_HOST_DEVICE void print_tensor(Tensor const& tensor) } } -template -CUTE_HOST_DEVICE void print(Tensor const& tensor) -{ - print(tensor.layout()); print("\n"); - print_tensor(tensor); -} + #if !defined(__CUDACC_RTC__) template @@ -943,7 +1009,6 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor const // #include - // // Tensor Algorithms // diff --git a/include/cute/util/debug.hpp b/include/cute/util/debug.hpp index 83e842946b..3e7df55867 100644 --- a/include/cute/util/debug.hpp +++ b/include/cute/util/debug.hpp @@ -84,7 +84,7 @@ namespace cute __FILE__, __LINE__, #e, \ cudaGetErrorName(code), cudaGetErrorString(code)); \ fflush(stderr); \ - exit(0); \ + exit(1); \ } \ } while (0) #endif @@ -98,15 +98,16 @@ namespace cute #endif // A dummy function that uses compilation failure to print a type -template +template CUTE_HOST_DEVICE void print_type() { - static_assert(sizeof(T) < 0, "Printing type T."); + static_assert(sizeof...(T) < 0, "Printing type T."); } -template + +template CUTE_HOST_DEVICE void -print_type(T&&) { - static_assert(sizeof(T) < 0, "Printing type T."); +print_type(T&&...) { + static_assert(sizeof...(T) < 0, "Printing type T."); } // diff --git a/include/cute/util/print.hpp b/include/cute/util/print.hpp index 320b4f5b99..88d2a9307d 100644 --- a/include/cute/util/print.hpp +++ b/include/cute/util/print.hpp @@ -124,7 +124,7 @@ print(char const& c) { } template ::value)> + __CUTE_REQUIRES(is_std_integral::value)> CUTE_HOST_DEVICE void print(T const& a) { @@ -138,4 +138,11 @@ print(char const* format, T const&... t) { printf(format, t...); } +template +CUTE_HOST_DEVICE +void +print(T const&... t) { + (print(t), ...); +} + } // end namespace cute diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp index 28e53597a9..e951e901c4 100644 --- a/include/cute/util/type_traits.hpp +++ b/include/cute/util/type_traits.hpp @@ -74,12 +74,15 @@ using CUTE_STL_NAMESPACE::is_void_v; using CUTE_STL_NAMESPACE::is_base_of; using CUTE_STL_NAMESPACE::is_base_of_v; +using CUTE_STL_NAMESPACE::is_const_v; + // using CUTE_STL_NAMESPACE::true_type; // using CUTE_STL_NAMESPACE::false_type; using CUTE_STL_NAMESPACE::conditional; using CUTE_STL_NAMESPACE::conditional_t; +using CUTE_STL_NAMESPACE::remove_const_t; using CUTE_STL_NAMESPACE::remove_cv_t; using CUTE_STL_NAMESPACE::remove_reference_t; @@ -89,6 +92,9 @@ using CUTE_STL_NAMESPACE::remove_extent; using CUTE_STL_NAMESPACE::decay; using CUTE_STL_NAMESPACE::decay_t; +using CUTE_STL_NAMESPACE::is_lvalue_reference; +using CUTE_STL_NAMESPACE::is_lvalue_reference_v; + using CUTE_STL_NAMESPACE::is_reference; using CUTE_STL_NAMESPACE::is_trivially_copyable; @@ -97,8 +103,16 @@ using CUTE_STL_NAMESPACE::is_same_v; using CUTE_STL_NAMESPACE::is_arithmetic; using CUTE_STL_NAMESPACE::is_unsigned; +using CUTE_STL_NAMESPACE::is_unsigned_v; using CUTE_STL_NAMESPACE::is_signed; +using CUTE_STL_NAMESPACE::is_signed_v; + +using CUTE_STL_NAMESPACE::make_signed; +using CUTE_STL_NAMESPACE::make_signed_t; + // using CUTE_STL_NAMESPACE::is_integral; +template +using is_std_integral = CUTE_STL_NAMESPACE::is_integral; using CUTE_STL_NAMESPACE::is_empty; @@ -107,6 +121,26 @@ using CUTE_STL_NAMESPACE::invoke_result_t; // using CUTE_STL_NAMESPACE::declval; +template< class T > +constexpr T&& forward(remove_reference_t& t) noexcept +{ + return static_cast(t); +} + +template< class T > +constexpr T&& forward(remove_reference_t&& t) noexcept +{ + static_assert(! is_lvalue_reference_v, + "T cannot be an lvalue reference (e.g., U&)."); + return static_cast(t); +} + +template< class T > +constexpr remove_reference_t&& move( T&& t ) noexcept +{ + return static_cast&&>(t); +} + // using CUTE_STL_NAMESPACE::numeric_limits; diff --git a/include/cutlass/arch/barrier.h b/include/cutlass/arch/barrier.h index 1d52141cc4..6b491404e0 100644 --- a/include/cutlass/arch/barrier.h +++ b/include/cutlass/arch/barrier.h @@ -36,18 +36,17 @@ #include #include -namespace cutlass { -/// @brief -namespace arch { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && (__CUDACC_VER_MAJOR__ >= 12) #define CUDA_BARRIER_ENABLED 1 #else #define CUDA_BARRIER_ENABLED 0 #endif +namespace cutlass { +/// @brief +namespace arch { + +//////////////////////////////////////////////////////////////////////////////////////////////////// class NamedBarrier { // Data Members: @@ -65,9 +64,9 @@ class NamedBarrier { NamedBarrier(uint32_t num_threads, uint32_t id = 0) : num_threads_(num_threads), id_(id) {} - CUTLASS_DEVICE - void arrive_and_wait() const { - NamedBarrier::arrive_and_wait(num_threads_, id_); + CUTLASS_DEVICE + void arrive_and_wait() const { + NamedBarrier::arrive_and_wait(num_threads_, id_); } CUTLASS_DEVICE @@ -75,13 +74,13 @@ class NamedBarrier { NamedBarrier::arrive(num_threads_, id_); } - CUTLASS_DEVICE - void sync() const { - NamedBarrier::arrive_and_wait(); + CUTLASS_DEVICE + void sync() const { + NamedBarrier::arrive_and_wait(); } // Static variants - CUTLASS_DEVICE + CUTLASS_DEVICE static void arrive_and_wait(uint32_t num_threads, uint32_t barrier_id) { #if CUDA_BARRIER_ENABLED asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); @@ -123,7 +122,7 @@ struct ClusterBarrier { CUTLASS_DEVICE ClusterBarrier() = delete; - CUTLASS_DEVICE + CUTLASS_DEVICE void init(uint32_t arrive_count) const { ClusterBarrier::init(&this->barrier_, arrive_count); } @@ -305,36 +304,44 @@ struct ClusterTransactionBarrier : public ClusterBarrier { CUTLASS_DEVICE ClusterTransactionBarrier() = delete; - // Performs an arrive operation + bytes reset + // Performs an arrive operation + expected transaction bytes increment CUTLASS_DEVICE - void arrive_and_reset_bytes(uint32_t transaction_bytes) const { - ClusterTransactionBarrier::arrive_and_reset_bytes(&this->barrier_, transaction_bytes); + void arrive_and_expect_tx(uint32_t transaction_bytes) const { + ClusterTransactionBarrier::arrive_and_expect_tx(&this->barrier_, transaction_bytes); } - // Performs an arrive operation + bytes reset + // Performs an arrive operation + expected transaction bytes increment CUTLASS_DEVICE - void arrive_and_reset_bytes(uint32_t transaction_bytes, uint32_t cta_id) const { - ClusterTransactionBarrier::arrive_and_reset_bytes(&this->barrier_, transaction_bytes , cta_id, true); + void arrive_and_expect_tx(uint32_t transaction_bytes, uint32_t cta_id) const { + ClusterTransactionBarrier::arrive_and_expect_tx(&this->barrier_, transaction_bytes , cta_id, true); } + // Performs an expected transaction bytes increment without doing an arrive operation CUTLASS_DEVICE - void commit(uint32_t transaction_bytes, uint32_t pred = 1) const { + void expect_transaction(uint32_t transaction_bytes) const { + ClusterTransactionBarrier::expect_transaction(&this->barrier_, transaction_bytes); + } + + // Performs an expected transaction bytes decrement without doing an arrive operation + CUTLASS_DEVICE + void complete_transaction(uint32_t transaction_bytes, uint32_t pred = 1) const { uint32_t cta_rank = cute::block_rank_in_cluster(); - ClusterTransactionBarrier::commit(&this->barrier_, cta_rank, transaction_bytes, pred); + ClusterTransactionBarrier::complete_transaction(&this->barrier_, cta_rank, transaction_bytes, pred); } + // Performs an expected transaction bytes decrement without doing an arrive operation CUTLASS_DEVICE - void commit(uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred) const { - ClusterTransactionBarrier::commit(&this->barrier_, dst_cta_id, transaction_bytes, pred); + void complete_transaction(uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred) const { + ClusterTransactionBarrier::complete_transaction(&this->barrier_, dst_cta_id, transaction_bytes, pred); } // // Static Versions // - // Performs an arrive operation + bytes reset + // Performs an arrive operation + expected transaction bytes increment CUTLASS_DEVICE - static void arrive_and_reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) { + static void arrive_and_expect_tx(ValueType const* smem_ptr, uint32_t transaction_bytes) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); asm volatile( @@ -348,9 +355,9 @@ struct ClusterTransactionBarrier : public ClusterBarrier { #endif } - // Performs an arrive operation + bytes reset for a remote cta_id in a Cluster + // Performs an arrive operation + expected transaction bytes increment for a remote cta_id in a Cluster CUTLASS_DEVICE - static void arrive_and_reset_bytes( + static void arrive_and_expect_tx( ValueType const* smem_ptr, uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -369,9 +376,9 @@ struct ClusterTransactionBarrier : public ClusterBarrier { #endif } - // Performs an bytes reset without doing an arrive operation + // Performs an expected transaction bytes increment without doing an arrive operation CUTLASS_DEVICE - static void reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) { + static void expect_transaction(ValueType const* smem_ptr, uint32_t transaction_bytes) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); asm volatile( @@ -385,9 +392,9 @@ struct ClusterTransactionBarrier : public ClusterBarrier { #endif } - // Increments transaction bytes in the barrier + // Performs an expected transaction bytes decrement without doing an arrive operation CUTLASS_DEVICE - static void commit( + static void complete_transaction( ValueType const* smem_ptr, uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred = 1) { #if CUDA_BARRIER_ENABLED uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); @@ -404,6 +411,46 @@ struct ClusterTransactionBarrier : public ClusterBarrier { asm volatile ("brkpt;\n" ::); #endif } + + // + // DEPRECATED APIs + // + [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE + void arrive_and_reset_bytes(uint32_t transaction_bytes) const { + arrive_and_expect_tx(transaction_bytes); + } + [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE + void arrive_and_reset_bytes(uint32_t transaction_bytes, uint32_t cta_id) const { + arrive_and_expect_tx(transaction_bytes, cta_id); + } + [[deprecated("Use expect_transaction instead")]] CUTLASS_DEVICE + void reset_bytes(uint32_t transaction_bytes) const { + expect_transaction(transaction_bytes); + } + [[deprecated("Use complete_transaction instead")]] CUTLASS_DEVICE + void commit(uint32_t transaction_bytes, uint32_t pred = 1) const { + complete_transaction(transaction_bytes, pred); + } + [[deprecated("Use complete_transaction instead")]] CUTLASS_DEVICE + void commit(uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred) const { + complete_transaction(dst_cta_id, transaction_bytes, pred); + } + [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE + static void arrive_and_reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) { + arrive_and_expect_tx(smem_ptr, transaction_bytes); + } + [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE + static void arrive_and_reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred) { + arrive_and_expect_tx(smem_ptr, transaction_bytes, cta_id, pred); + } + [[deprecated("Use expect_transaction instead")]] CUTLASS_DEVICE + static void reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) { + expect_transaction(smem_ptr, transaction_bytes); + } + [[deprecated("Use complete_transaction instead")]] CUTLASS_DEVICE + static void commit(ValueType const* smem_ptr, uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred = 1) { + complete_transaction(smem_ptr, dst_cta_id, transaction_bytes, pred); + } }; // Helps with visibility of barrier init operations across warps / cta / cluster diff --git a/include/cutlass/arch/memory_sm80.h b/include/cutlass/arch/memory_sm80.h index 04bab1d66e..a8fd33042d 100644 --- a/include/cutlass/arch/memory_sm80.h +++ b/include/cutlass/arch/memory_sm80.h @@ -326,7 +326,6 @@ struct cp_async { "cp.async only supports CacheOperation::Global when access size is 16B."); unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); - asm volatile( "{\n" " .reg .pred p;\n" @@ -365,7 +364,6 @@ struct cp_async_zfill { unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); int src_in_bytes = (pred_guard ? SizeInBytes : 0); - asm volatile( #if CUTLASS_ENABLE_L2_PREFETCH "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), @@ -403,7 +401,6 @@ struct cp_async_nan<16, CacheOperation::Global> { OOB_NAN_F16x2, OOB_NAN_F16x2}; unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); - asm volatile( "{\n" " .reg .pred p;\n" diff --git a/include/cutlass/arch/mma.h b/include/cutlass/arch/mma.h index f35cdb349d..f7c59e6330 100644 --- a/include/cutlass/arch/mma.h +++ b/include/cutlass/arch/mma.h @@ -99,6 +99,11 @@ struct OpXorPopc {}; ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Tag indicating the inner product is defined by (AND, POPC) +struct OpAndPopc {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Tag classifying math operators as thread-level operations. struct OpClassSimt {}; @@ -113,6 +118,11 @@ struct OpClassWmmaTensorOp {}; ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Tag classifing operators as Tensor Core with structure sparse operations. +struct OpClassSparseTensorOp {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Matrix multiply-add operation template < /// Size of the matrix product (concept: GemmShape) diff --git a/include/cutlass/arch/mma_sm75.h b/include/cutlass/arch/mma_sm75.h index 1402e76da6..4d6c63102c 100644 --- a/include/cutlass/arch/mma_sm75.h +++ b/include/cutlass/arch/mma_sm75.h @@ -1248,7 +1248,7 @@ struct Mma< #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) -#if (__CUDA_ARCH__ >= 900) || (defined(CUTLASS_ARCH_WMMA_ENABLED)) +#if defined(CUTLASS_ARCH_WMMA_ENABLED) using WmmaFragmentA = nvcuda::wmma::fragment< nvcuda::wmma::matrix_a, Shape::kM, diff --git a/include/cutlass/arch/mma_sm80.h b/include/cutlass/arch/mma_sm80.h index 8682ae1ba8..c01a7b07c4 100644 --- a/include/cutlass/arch/mma_sm80.h +++ b/include/cutlass/arch/mma_sm80.h @@ -2039,6 +2039,77 @@ struct Mma< } }; +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 168256 - B1 input, S32 accumulation - AND,POPC +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = B1 & B1 + S32 +template <> +struct Mma< + gemm::GemmShape<16,8,256>, + 32, + cutlass::uint1b_t, + layout::RowMajor, + cutlass::uint1b_t, + layout::ColumnMajor, + int32_t, + layout::RowMajor, + OpAndPopc> { + + using Shape = gemm::GemmShape<16,8,256>; + + using ElementA = cutlass::uint1b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint1b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int32_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpAndPopc; + using ArchTag = arch::Sm80; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c + ) const { + +#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + asm volatile( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + /// Matrix multiply-add operation: S32 = B1 & B1 + S32 template <> struct Mma< diff --git a/include/cutlass/arch/simd_sm60.h b/include/cutlass/arch/simd_sm60.h index 16d528b5f1..d8c38aae5e 100644 --- a/include/cutlass/arch/simd_sm60.h +++ b/include/cutlass/arch/simd_sm60.h @@ -50,8 +50,6 @@ template <> Array operator*(Array const &a, Array const &b) { Array d; - // TODO - return d; } @@ -60,8 +58,6 @@ template <> Array operator+(AArray const &a, Array const &b) { Array d; - // TODO - return d; } @@ -70,8 +66,6 @@ template <> Array operator-(Array const &a, Array const &b) { Array d; - // TODO - return d; } @@ -83,8 +77,6 @@ template <> Array mac(Array const &a, Array const &b, Array const &c) { Array d; - // TODO - return d; } @@ -95,8 +87,6 @@ CUTLASS_HOST_DEVICE template <> half_t dot(Array const &a, Array const &b, half_t accum) { - // TODO - return accum; } @@ -105,8 +95,6 @@ CUTLASS_HOST_DEVICE template <> float dot(Array const &a, Array const &b, float accum) { - // TODO - return accum; } diff --git a/include/cutlass/array.h b/include/cutlass/array.h index 9fe245beeb..19d16cc251 100644 --- a/include/cutlass/array.h +++ b/include/cutlass/array.h @@ -38,7 +38,6 @@ #include "cutlass/functional.h" #include "cutlass/numeric_types.h" #include "cutlass/half.h" - namespace cutlass { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -717,6 +716,27 @@ struct multiplies> { } }; +template +struct scale> { + T const scaling_factor_; + + CUTLASS_HOST_DEVICE + scale(T scaling_factor) : scaling_factor_(scaling_factor) { + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & rhs) const { + Array result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = rhs[i] * scaling_factor_; + } + + return result; + } +}; + template struct divides> { @@ -764,13 +784,13 @@ struct divides> { }; template -struct maximum> { +struct maximum, false> { CUTLASS_HOST_DEVICE Array operator()(Array const &lhs, Array const &rhs) const { Array result; - maximum scalar_op; + maximum scalar_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { @@ -784,7 +804,7 @@ struct maximum> { Array operator()(Array const &lhs, T const &scalar) const { Array result; - maximum scalar_op; + maximum scalar_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { @@ -795,10 +815,56 @@ struct maximum> { } CUTLASS_HOST_DEVICE - Array operator()( T const &scalar, Array const &rhs) const { + Array operator()(T const &scalar, Array const &rhs) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct maximum, true> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &rhs) const { Array result; - maximum scalar_op; + maximum scalar_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { @@ -810,7 +876,7 @@ struct maximum> { }; template -struct minimum> { +struct minimum, false> { CUTLASS_HOST_DEVICE static T scalar_op(T const &lhs, T const &rhs) { @@ -821,7 +887,7 @@ struct minimum> { Array operator()(Array const &lhs, Array const &rhs) const { Array result; - minimum scalar_op; + minimum scalar_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { @@ -835,7 +901,7 @@ struct minimum> { Array operator()(Array const &lhs, T const &scalar) const { Array result; - minimum scalar_op; + minimum scalar_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { @@ -846,10 +912,61 @@ struct minimum> { } CUTLASS_HOST_DEVICE - Array operator()( T const &scalar, Array const &rhs) const { + Array operator()(T const &scalar, Array const &rhs) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct minimum, true> { + + CUTLASS_HOST_DEVICE + static T scalar_op(T const &lhs, T const &rhs) { + return (rhs < lhs ? rhs : lhs); + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &rhs) const { Array result; - minimum scalar_op; + minimum scalar_op; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { @@ -1013,7 +1130,7 @@ struct plus> { result_ptr[i] = __hadd2(lhs_ptr[i], rhs_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); __half d_residual = __hadd(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); @@ -1046,7 +1163,7 @@ struct plus> { result_ptr[i] = __hadd2(lhs_pair, rhs_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); __half d_residual = __hadd(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); @@ -1078,7 +1195,7 @@ struct plus> { result_ptr[i] = __hadd2(lhs_ptr[i], rhs_pair); } - if (N % 2) { + if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); __half d_residual = __hadd(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); @@ -1113,7 +1230,7 @@ struct minus> { result_ptr[i] = __hsub2(lhs_ptr[i], rhs_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); __half d_residual = __hsub(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); @@ -1146,7 +1263,7 @@ struct minus> { result_ptr[i] = __hsub2(lhs_pair, rhs_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); __half d_residual = __hsub(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); @@ -1178,7 +1295,7 @@ struct minus> { result_ptr[i] = __hsub2(lhs_ptr[i], rhs_pair); } - if (N % 2) { + if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); __half d_residual = __hsub(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); @@ -1213,7 +1330,7 @@ struct multiplies> { result_ptr[i] = __hmul2(lhs_ptr[i], rhs_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); __half d_residual = __hmul(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); @@ -1246,7 +1363,7 @@ struct multiplies> { result_ptr[i] = __hmul2(lhs_pair, rhs_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); __half d_residual = __hmul( @@ -1281,7 +1398,7 @@ struct multiplies> { result_ptr[i] = __hmul2(lhs_ptr[i], rhs_pair); } - if (N % 2) { + if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); __half d_residual = __hmul( @@ -1319,7 +1436,7 @@ struct divides> { result_ptr[i] = __h2div(lhs_ptr[i], rhs_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); @@ -1355,7 +1472,7 @@ struct divides> { result_ptr[i] = __h2div(lhs_pair, rhs_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); __half d_residual = __hdiv( @@ -1390,7 +1507,7 @@ struct divides> { result_ptr[i] = __h2div(lhs_ptr[i], rhs_pair); } - if (N % 2) { + if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); __half d_residual = __hdiv( @@ -1427,7 +1544,7 @@ struct negate> { result_ptr[i] = __hneg2(source_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { half_t x = lhs[N - 1]; __half lhs_val = -reinterpret_cast<__half const &>(x); result[N - 1] = reinterpret_cast(lhs_val); @@ -1468,7 +1585,7 @@ struct multiply_add, Array, Array> { result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); @@ -1514,7 +1631,7 @@ struct multiply_add, Array, Array> { result_ptr[i] = __hfma2(a_pair, b_ptr[i], c_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); @@ -1558,7 +1675,7 @@ struct multiply_add, Array, Array> { result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); @@ -1603,7 +1720,7 @@ struct multiply_add, Array, Array> { result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_pair); } - if (N % 2) { + if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); @@ -1653,7 +1770,7 @@ struct multiply_add_relu0, Array, Array> result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); @@ -1700,7 +1817,7 @@ struct multiply_add_relu0, Array, Array> result_ptr[i] = __hfma2_relu(a_pair, b_ptr[i], c_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); @@ -1745,7 +1862,7 @@ struct multiply_add_relu0, Array, Array> result_ptr[i] = __hfma2_relu(a_ptr[i], b_pair, c_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); @@ -1791,7 +1908,7 @@ struct multiply_add_relu0, Array, Array> result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_pair); } - if (N % 2) { + if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); @@ -1820,7 +1937,7 @@ struct multiply_add_relu0, Array, Array> }; template -struct minimum> { +struct minimum, false> { CUTLASS_HOST_DEVICE Array operator()(Array const & lhs, Array const &rhs) const { Array result; @@ -1835,7 +1952,7 @@ struct minimum> { result_ptr[i] = __hmin2(lhs_ptr[i], rhs_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); @@ -1871,7 +1988,7 @@ struct minimum> { result_ptr[i] = __hmin2(lhs_pair, rhs_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); __half d_residual = __hmin( @@ -1906,7 +2023,7 @@ struct minimum> { result_ptr[i] = __hmin2(lhs_ptr[i], rhs_pair); } - if (N % 2) { + if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); __half d_residual = __hmin( @@ -1929,7 +2046,7 @@ struct minimum> { }; template -struct maximum> { +struct maximum, false> { CUTLASS_HOST_DEVICE Array operator()(Array const & lhs, Array const &rhs) const { Array result; @@ -1944,7 +2061,7 @@ struct maximum> { result_ptr[i] = __hmax2(lhs_ptr[i], rhs_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); @@ -1980,7 +2097,7 @@ struct maximum> { result_ptr[i] = __hmax2(lhs_pair, rhs_ptr[i]); } - if (N % 2) { + if constexpr (N % 2) { __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); __half d_residual = __hmax( @@ -2015,7 +2132,7 @@ struct maximum> { result_ptr[i] = __hmax2(lhs_ptr[i], rhs_pair); } - if (N % 2) { + if constexpr (N % 2) { __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); __half d_residual = __hmax( @@ -2063,7 +2180,7 @@ struct multiply_add, Array, Array(&result); uint16_t const *a_residual_ptr = reinterpret_cast(&a); @@ -2114,7 +2231,7 @@ struct multiply_add, Array, Array(&result); uint16_t const *a_residual_ptr = reinterpret_cast(&a); @@ -2165,7 +2282,7 @@ struct multiply_add, Array, Array(&result); uint16_t const *a_residual_ptr = reinterpret_cast(&a); @@ -2216,7 +2333,7 @@ struct multiply_add, Array, Array(&result); uint16_t const *a_residual_ptr = reinterpret_cast(&a); diff --git a/include/cutlass/barrier.h b/include/cutlass/barrier.h index b74e103889..a8a26c5011 100644 --- a/include/cutlass/barrier.h +++ b/include/cutlass/barrier.h @@ -35,16 +35,43 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { +namespace detail { + +// +// Utilities for abstracting synchronization methods for barriers +// + +struct SyncthreadsSync { + CUTLASS_DEVICE + static void sync() { + __syncthreads(); + } +}; + +template < + int ThreadCount, + int BarrierId +> +struct NamedBarrierSync { + CUTLASS_DEVICE + static void sync() { + cutlass::arch::NamedBarrier::sync(ThreadCount, BarrierId); + } +}; + +} // namepspace detail + ///////////////////////////////////////////////////////////////////////////////////////////////// -/// CTA-wide semaphore for inter-CTA synchronization. -struct Barrier -{ +/// Group or CTA-wide semaphore for inter-CTA synchronization. +template +struct GenericBarrier { public: @@ -111,7 +138,7 @@ struct Barrier while(ld_acquire(flag_ptr) < count) {} } - __syncthreads(); + Sync::sync(); } /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter @@ -126,7 +153,7 @@ struct Barrier #pragma unroll 1 while(ld_acquire(flag_ptr) != val) {} } - __syncthreads(); + Sync::sync(); } /// Uses thread[0] to wait for the specified count of signals on the given flag counter @@ -141,45 +168,149 @@ struct Barrier while(atomicCAS(flag_ptr, val, 0) != val) {} } - __syncthreads(); + Sync::sync(); } /// Increment the arrival count for a flag CUTLASS_DEVICE - static void arrive_inc(void *lock_ptr, int thread_idx, int flag_idx) + static void arrive_inc(void *lock_ptr, int thread_idx, int flag_idx, int val = 1) { T* flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; - __syncthreads(); + Sync::sync(); if (thread_idx == 0) { - red_release(flag_ptr, 1); + red_release(flag_ptr, val); } } /// Increment the arrival counts for a range of flags CUTLASS_DEVICE - static void arrive_range_inc(void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1) + static void arrive_range_inc(void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) { int flag_idx = first_flag_idx + thread_idx; T* flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; - // Barrier to make sure all other threads in block have written their data - __syncthreads(); + // Barrier to make sure all other threads in group have written their data + Sync::sync(); // Select threads increment their flags if (thread_idx < count) { - red_release(flag_ptr, 1); + red_release(flag_ptr, val); } } }; - +using Barrier = GenericBarrier; ///////////////////////////////////////////////////////////////////////////////////////////////// +/** Structure for managing multiple NamedBarriers to be used by different warp groups, allowing + * runtime index values to be used to call into named barriers with compile-time-constant IDs. + * + * @param ThreadCount_ Number of threads that will wait on a NamedBarrier with a given ID + * @param Offset Value added to the ID passed in by the user to determine the NamedBarrier ID to call into +**/ +template < + uint32_t ThreadCount_, + uint32_t Offset = 0 +> +struct NamedBarrierManager { + static constexpr uint32_t MaxNumNamedBarriers = 16; + static_assert(Offset < MaxNumNamedBarriers, "Barrier IDs cannot exceed 15"); + static constexpr uint32_t ValidBarrierIds = MaxNumNamedBarriers - Offset; + + // Number of threads participating in the barrier + static constexpr uint32_t ThreadCount = ThreadCount_; + + template + using BarrierSync = cutlass::GenericBarrier>; + + // Underlying type used by all barriers for synchronization. Does not depend on + // template parameter BarrierId, so passing in 0 suffices. + using T = typename BarrierSync<0>::T; + + using IntegerSequence = cute::make_integer_sequence; + + CUTLASS_DEVICE + static + void wait_lt(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int count) { + wait_lt_helper(idx, lock_ptr, thread_idx, flag_idx, count, IntegerSequence{}); + } + + CUTLASS_DEVICE + static void + wait_eq(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + wait_eq_helper(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{}); + } + + CUTLASS_DEVICE + static void + wait_eq_reset(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + wait_eq_helper(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{}); + } + + CUTLASS_DEVICE + static void + arrive_inc(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int val = 1) { + arrive_inc_helper(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{}); + } + + CUTLASS_DEVICE + static void + arrive_range_inc(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) { + arrive_range_inc_helper(idx, lock_ptr, thread_idx, first_flag_idx, count, val, IntegerSequence{}); + } + +private: + CUTLASS_DEVICE + static void + check_barrier_in_range(uint32_t idx) { + if (idx >= ValidBarrierIds) { + CUTE_RUNTIME_ASSERT("Index exceeds barrier count"); + } + } + + template + CUTLASS_DEVICE + static void + wait_lt_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int count, cute::integer_sequence) { + check_barrier_in_range(idx); + ((Idx == idx && (BarrierSync::wait_lt(lock_ptr, thread_idx, flag_idx, count), true)) || ...); + } + + template + CUTLASS_DEVICE + static void + wait_eq_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val, cute::integer_sequence) { + check_barrier_in_range(idx); + if constexpr (Reset) { + ((Idx == idx && (BarrierSync::wait_eq_reset(lock_ptr, thread_idx, flag_idx, val), true)) || ...); + } + else { + ((Idx == idx && (BarrierSync::wait_eq(lock_ptr, thread_idx, flag_idx, val), true)) || ...); + } + } + + template + CUTLASS_DEVICE + static void + arrive_inc_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int val, cute::integer_sequence) { + check_barrier_in_range(idx); + ((Idx == idx && (BarrierSync::arrive_inc(lock_ptr, thread_idx, flag_idx, val), true)) || ...); + } + + template + CUTLASS_DEVICE + static void + arrive_range_inc_helper(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count, int val, cute::integer_sequence) { + check_barrier_in_range(idx); + ((Idx == idx && (BarrierSync::arrive_range_inc(lock_ptr, thread_idx, first_flag_idx, count, val), true)) || ...); + } +}; + } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/blas3.h b/include/cutlass/blas3.h index f5f8a0905a..90ead4a648 100644 --- a/include/cutlass/blas3.h +++ b/include/cutlass/blas3.h @@ -39,6 +39,7 @@ #include "cutlass/cutlass.h" #include "cutlass/array.h" +#include "cutlass/blas3_types.h" #include "cutlass/coord.h" #include "cutlass/complex.h" #include "cutlass/functional.h" @@ -49,41 +50,7 @@ namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Enumerated type describing the type of kernel (based on input or output matrices). -enum class BlasMode { - kGemm, - kSymmetric, - kHermitian, - kTriangular, - kInvalid -}; - -/// Enumerated type describing the fill mode for matrices for BLAS functions. -enum class FillMode { - kFull, /// The entire tensor is covered. - kLower, /// The 'lower' part of a tensor is covered including diagonal - kUpper, /// The 'upper' part of a tensor is covered including diaognal - kDiagonal, /// Only diagonal elements are covered. - kNone, /// No element is covered. - kInvalid -}; -/// Enumerated type describing the diagonal property of matrices for BLAS functions. -enum class DiagType { - kNonUnit, - kUnit, - kZero, // Only used internally for computing SYMM/HEMM - kInvalid -}; - -/// Enumerated type describing the side dense matrix is in matrix equation for BLAS functions. -enum class SideMode { - kLeft, - kRight, - kInvalid -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// /// Defines FillMode inversions template struct InvertFillMode; diff --git a/include/cutlass/blas3_types.h b/include/cutlass/blas3_types.h new file mode 100644 index 0000000000..a1df71fb8b --- /dev/null +++ b/include/cutlass/blas3_types.h @@ -0,0 +1,78 @@ +/*************************************************************************************************** + * Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Enumerated type describing the type of kernel (based on input or output matrices). +enum class BlasMode { + kGemm, + kSymmetric, + kHermitian, + kTriangular, + kInvalid +}; + +/// Enumerated type describing the fill mode for matrices for BLAS functions. +enum class FillMode { + kFull, /// The entire tensor is covered. + kLower, /// The 'lower' part of a tensor is covered including diagonal + kUpper, /// The 'upper' part of a tensor is covered including diaognal + kDiagonal, /// Only diagonal elements are covered. + kNone, /// No element is covered. + kInvalid +}; + +/// Enumerated type describing the diagonal property of matrices for BLAS functions. +enum class DiagType { + kNonUnit, + kUnit, + kZero, // Only used internally for computing SYMM/HEMM + kInvalid +}; + +/// Enumerated type describing the side dense matrix is in matrix equation for BLAS functions. +enum class SideMode { + kLeft, + kRight, + kInvalid +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/block_striped.h b/include/cutlass/block_striped.h index 563e619d05..f86c3ea2b9 100644 --- a/include/cutlass/block_striped.h +++ b/include/cutlass/block_striped.h @@ -220,7 +220,7 @@ struct BlockStripedReduce : CUTLASS_DEVICE static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx) { - cutlass::red reduce; + cutlass::atomic_add reduce; ElementT *access_output = reinterpret_cast(ptr); const ElementT *access_data = reinterpret_cast(&data); @@ -250,7 +250,7 @@ struct BlockStripedReduce : CUTLASS_DEVICE static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx) { - cutlass::red reduce; + cutlass::atomic_add reduce; half2 *access_output = reinterpret_cast(ptr); const half2 *access_data = reinterpret_cast(&data); diff --git a/include/cutlass/cluster_launch.hpp b/include/cutlass/cluster_launch.hpp index b405e2e26c..ddf7571338 100644 --- a/include/cutlass/cluster_launch.hpp +++ b/include/cutlass/cluster_launch.hpp @@ -223,13 +223,21 @@ launch_kernel_on_cluster(const ClusterLaunchParams& params, { // Unfortunately, we find ourselves needing to pass in // the parameters as an array of raw pointers. - void* kernel_params[] = { - detail::checked_addressof(std::forward(args))... - }; - return cutlass::ClusterLauncher::launch( - params.grid_dims, params.cluster_dims, params.block_dims, - params.smem_size_in_bytes, params.cuda_stream, - kernel_ptr, kernel_params); + if constexpr (sizeof...(Args) == 0) { + return cutlass::ClusterLauncher::launch( + params.grid_dims, params.cluster_dims, params.block_dims, + params.smem_size_in_bytes, params.cuda_stream, + kernel_ptr, nullptr); + } + else { + void* kernel_params[sizeof...(Args)] = { + detail::checked_addressof(std::forward(args))... + }; + return cutlass::ClusterLauncher::launch( + params.grid_dims, params.cluster_dims, params.block_dims, + params.smem_size_in_bytes, params.cuda_stream, + kernel_ptr, kernel_params); + } } } // namespace cutlass diff --git a/include/cutlass/complex.h b/include/cutlass/complex.h index a3f56e4bbe..ffce5d09b7 100644 --- a/include/cutlass/complex.h +++ b/include/cutlass/complex.h @@ -199,7 +199,7 @@ class complex template CUTLASS_DEVICE void red(complex *ptr) const { static_assert(platform::is_same::value, "Component type must match"); - cutlass::red reduce; + cutlass::atomic_add reduce; reduce(&ptr->_real, _real); reduce(&ptr->_imag, _imag); } @@ -209,7 +209,7 @@ class complex static_assert(platform::is_same::value, "Component type must match"); half2 *h2_ptr = reinterpret_cast(ptr); half2 h2_data = reinterpret_cast(*this); - cutlass::red reduce; + cutlass::atomic_add reduce; reduce(h2_ptr, h2_data); } @@ -514,7 +514,6 @@ CUTLASS_HOST_DEVICE complex sin(complex const &z) { /// Comparison template CUTLASS_HOST_DEVICE bool operator<(complex const &lhs, complex const &rhs) { - //TODO return true; } @@ -679,7 +678,7 @@ struct magnitude_squared_difference, Output> { /// Reduces value into the data pointed to by ptr (complex specialization) template -struct red> { +struct atomic_add> { CUTLASS_DEVICE void operator()(complex *ptr, const complex &data) { diff --git a/include/cutlass/conv/convolution.h b/include/cutlass/conv/convolution.h index 7f800e4cb5..2984901b9d 100644 --- a/include/cutlass/conv/convolution.h +++ b/include/cutlass/conv/convolution.h @@ -73,6 +73,7 @@ Map elements' data types (Conv -> ImplicitGemm): ConvToGemmElementMap #pragma once #include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" #include "cutlass/tensor_coord.h" #include "cutlass/fast_math.h" #include "cutlass/gemm/gemm.h" diff --git a/include/cutlass/conv/device/direct_convolution.h b/include/cutlass/conv/device/direct_convolution.h index d7f28f10a7..af29a10a25 100644 --- a/include/cutlass/conv/device/direct_convolution.h +++ b/include/cutlass/conv/device/direct_convolution.h @@ -229,7 +229,6 @@ class DirectConvolution { if (status != cudaSuccess) return Status::kErrorInternal; - cutlass::Kernel<<>>(params_); cudaError_t result = cudaGetLastError(); diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution.h b/include/cutlass/conv/kernel/implicit_gemm_convolution.h index 11ac967c65..2669ff7758 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution.h @@ -332,7 +332,7 @@ struct ImplicitGemmConvolution { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; // diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h index b740c9058f..8183d5c064 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h @@ -339,7 +339,7 @@ struct ImplicitGemmConvolutionFusion { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; // diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h index 7304cbdecb..6ebfcdede0 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h @@ -207,6 +207,7 @@ struct ImplicitGemmConvolutionStridedDgrad { struct Params { ConvProblemSize problem_size; cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; FastDivmod stride_h_divmod; FastDivmod stride_w_divmod; int gemm_k_iterations; @@ -259,6 +260,8 @@ struct ImplicitGemmConvolutionStridedDgrad { args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.problem_size.split_k_slices); + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); } }; @@ -283,7 +286,7 @@ struct ImplicitGemmConvolutionStridedDgrad { ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || @@ -335,7 +338,7 @@ struct ImplicitGemmConvolutionStridedDgrad { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; // Check if CTA contributes valid MMA (Dy * w) and accumulator will be non-zero after MMA @@ -393,7 +396,7 @@ struct ImplicitGemmConvolutionStridedDgrad { // Compute logical position within grid threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // If performing a reduction via split-K, fetch the initial synchronization if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h index 3fa7daca1b..c6e7a81326 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h @@ -341,7 +341,7 @@ struct ImplicitGemmConvolutionWithFusedEpilogue { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; // diff --git a/include/cutlass/conv/warp/mma_depthwise_simt.h b/include/cutlass/conv/warp/mma_depthwise_simt.h index ae49cc10c4..0ba0a0d738 100644 --- a/include/cutlass/conv/warp/mma_depthwise_simt.h +++ b/include/cutlass/conv/warp/mma_depthwise_simt.h @@ -368,7 +368,6 @@ class MmaDepthwiseDirectConvSimt { CUTLASS_DEVICE void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, FragmentA const &A, FragmentB const &B) const { - //TODO: Implement this dst_A = A; dst_B = B; } diff --git a/include/cutlass/conv/warp/scale_bias_relu_transform.h b/include/cutlass/conv/warp/scale_bias_relu_transform.h index a1a4dff4f0..d7dd567520 100644 --- a/include/cutlass/conv/warp/scale_bias_relu_transform.h +++ b/include/cutlass/conv/warp/scale_bias_relu_transform.h @@ -103,7 +103,6 @@ struct FpropScaleBiasReluTransform { : "r"(ptr_scale_bias[0]), "r"(ptr_activations[0]), "r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16x2)); #else - // TODO: write emulation code assert(0); #endif } @@ -198,7 +197,6 @@ struct WgradScaleBiasReluTransform { "r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16), "n"(0xffff0000), "n"(0x0000ffff)); #endif #else - // TODO: write emulation code assert(0); #endif } diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index ab7b6c8d05..bbef6fc2c6 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -35,61 +35,7 @@ #pragma once -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef CUTLASS_NAMESPACE -#define concat_tok(a, b) a ## b -#define mkcutlassnamespace(pre, ns) concat_tok(pre, ns) -#define cutlass mkcutlassnamespace(cutlass_, CUTLASS_NAMESPACE) -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) -#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__ -#define CUTLASS_DEVICE __forceinline__ __device__ -#elif defined(__CUDACC_RTC__) -#define CUTLASS_HOST_DEVICE __forceinline__ __device__ -#define CUTLASS_DEVICE __forceinline__ __device__ -#else -#define CUTLASS_HOST_DEVICE inline -#define CUTLASS_DEVICE inline -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -CUTLASS_HOST_DEVICE void __CUTLASS_UNUSED(T const &) -{ } - -#if defined(__GNUC__) - #define CUTLASS_UNUSED(expr) __CUTLASS_UNUSED(expr) -#else - #define CUTLASS_UNUSED(expr) do { ; } while (&expr != &expr) -#endif - -#ifdef _MSC_VER -// Provides support for alternative operators 'and', 'or', and 'not' -#include -#endif // _MSC_VER - -#if !defined(__CUDACC_RTC__) -#include -#endif - -#if defined(__CUDA_ARCH__) - #if defined(_MSC_VER) - #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __FUNCSIG__); asm volatile ("brkpt;\n"); } - #else - #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __PRETTY_FUNCTION__); asm volatile ("brkpt;\n"); } - #endif -#else - #if defined(_MSC_VER) - #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __FUNCSIG__) - #else - #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __PRETTY_FUNCTION__) - #endif -#endif +#include "cutlass/detail/helper_macros.hpp" //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -145,47 +91,9 @@ static char const* cutlassGetStatusString(cutlass::Status status) { //////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifndef CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED -#define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 0 -#endif - - -// CUDA 10.1 introduces the mma instruction -#if !defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) -#define CUTLASS_ENABLE_TENSOR_CORE_MMA 0 -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define CUTLASS_ASSERT(x) assert(x) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler. -#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__) - #if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__)) - #define CUTLASS_PRAGMA_UNROLL _Pragma("unroll") - #define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1") - #else - #define CUTLASS_PRAGMA_UNROLL #pragma unroll - #define CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 1 - #endif - - #define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL - -#else - - #define CUTLASS_PRAGMA_UNROLL - #define CUTLASS_PRAGMA_NO_UNROLL - #define CUTLASS_GEMM_LOOP - -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - static const int NumThreadsPerWarp = 32; static const int NumThreadsPerWarpGroup = 128; +static const int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; static const int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2; static const int NumThreadsPerQuad = 4; static const int NumThreadsPerQuadPair = NumThreadsPerQuad * 2; @@ -201,10 +109,20 @@ CUTLASS_HOST_DEVICE bool thread0() { #endif } +/// Returns a lane index in the warp. The threads in warp may not be convergent +CUTLASS_DEVICE +int canonical_lane_idx() { + #if defined(__CUDA_ARCH__) + return threadIdx.x % NumThreadsPerWarp; + #else + return 0; + #endif +} + /// Returns a warp-uniform value indicating the canonical warp index of the calling threads. /// Threads within the warp must be converged. CUTLASS_DEVICE -int canonical_warp_idx() { +int canonical_warp_idx_sync() { #if defined(__CUDA_ARCH__) return __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarp, 0); #else @@ -212,6 +130,17 @@ int canonical_warp_idx() { #endif } +/// Returns a warp index in the CTA. The threads in warp may not be convergent +/// As it doesn't sync the warp, it faster and allows forward progress +CUTLASS_DEVICE +int canonical_warp_idx() { + #if defined(__CUDA_ARCH__) + return threadIdx.x / NumThreadsPerWarp; + #else + return 0; + #endif +} + /// Returns a warp-uniform value indicating the canonical warp group index of the calling threads. /// Threads within the warp must be converged. CUTLASS_DEVICE diff --git a/include/cutlass/detail/helper_macros.hpp b/include/cutlass/detail/helper_macros.hpp new file mode 100644 index 0000000000..0c3a9cd2f4 --- /dev/null +++ b/include/cutlass/detail/helper_macros.hpp @@ -0,0 +1,144 @@ +/*************************************************************************************************** + * Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Helper macros for the CUTLASS library +*/ + +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +#ifdef CUTLASS_NAMESPACE +#define concat_tok(a, b) a ## b +#define mkcutlassnamespace(pre, ns) concat_tok(pre, ns) +#define cutlass mkcutlassnamespace(cutlass_, CUTLASS_NAMESPACE) +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) +#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__ +#define CUTLASS_DEVICE __forceinline__ __device__ +#elif defined(__CUDACC_RTC__) +#define CUTLASS_HOST_DEVICE __forceinline__ __device__ +#define CUTLASS_DEVICE __forceinline__ __device__ +#else +#define CUTLASS_HOST_DEVICE inline +#define CUTLASS_DEVICE inline +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_HOST_DEVICE void __CUTLASS_UNUSED(T const &) +{ } + +#if defined(__GNUC__) + #define CUTLASS_UNUSED(expr) __CUTLASS_UNUSED(expr) +#else + #define CUTLASS_UNUSED(expr) do { ; } while (&expr != &expr) +#endif + +#ifdef _MSC_VER +// Provides support for alternative operators 'and', 'or', and 'not' +#include +#endif // _MSC_VER + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#if defined(__CUDA_ARCH__) + #if defined(_MSC_VER) + #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __FUNCSIG__); asm volatile ("brkpt;\n"); } + #else + #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __PRETTY_FUNCTION__); asm volatile ("brkpt;\n"); } + #endif +#else + #if defined(_MSC_VER) + #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __FUNCSIG__) + #else + #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __PRETTY_FUNCTION__) + #endif +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + + +#ifndef CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED +#define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 0 +#endif + + +// CUDA 10.1 introduces the mma instruction +#if !defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) +#define CUTLASS_ENABLE_TENSOR_CORE_MMA 0 +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define CUTLASS_ASSERT(x) assert(x) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler. +#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__) + #if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__)) + #define CUTLASS_PRAGMA_UNROLL _Pragma("unroll") + #define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1") + #else + #define CUTLASS_PRAGMA_UNROLL #pragma unroll + #define CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 1 + #endif + + #define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL + +#else + + #define CUTLASS_PRAGMA_UNROLL + #define CUTLASS_PRAGMA_NO_UNROLL + #define CUTLASS_GEMM_LOOP + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if !defined(__CUDACC_RTC__) +#define CUTLASS_THREAD_LOCAL thread_local +#else +#define CUTLASS_THREAD_LOCAL +#endif + +}; // namespace cutlass diff --git a/include/cutlass/detail/layout.hpp b/include/cutlass/detail/layout.hpp new file mode 100644 index 0000000000..da76f0d655 --- /dev/null +++ b/include/cutlass/detail/layout.hpp @@ -0,0 +1,244 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" + +#include "cute/layout.hpp" +#include "cute/arch/copy_sm90_tma.hpp" +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::detail { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// For each cutlass::layout, provides its corresponding cute stride types, 64b by default + +template +struct TagToStrideA { + using type = L; +}; + +// Maps to modes [M, K, L] +template <> +struct TagToStrideA { + using type = cute::Stride, int64_t>; + using tag = layout::RowMajor; +}; + +// Maps to modes [M, K, L] +template <> +struct TagToStrideA { + using type = cute::Stride, int64_t, int64_t>; + using tag = layout::ColumnMajor; +}; + +template +struct TagToStrideB { + using type = L; +}; + +// Maps to modes [N, K, L] +template <> +struct TagToStrideB { + using type = cute::Stride, int64_t, int64_t>; + using tag = layout::RowMajor; +}; + +// Maps to modes [N, K, L] +template <> +struct TagToStrideB { + using type = cute::Stride, int64_t>; + using tag = layout::ColumnMajor; +}; + +// Maps to modes [M, N, L] +template +struct TagToStrideC : TagToStrideA { }; + +// Convenience aliases +template +using TagToStrideA_t = typename TagToStrideA::type; + +template +using TagToStrideB_t = typename TagToStrideB::type; + +template +using TagToStrideC_t = typename TagToStrideC::type; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// For 2.x compatibility APIs, provide stride->layout tag mappers + +template +constexpr bool +is_major(Stride = {}) { + // Account for stride types with and without batch mode and batch modes with static zero stride + return cute::is_constant<1, decltype(cute::front(cute::get(Stride{})))>::value; +} + +// Note : This method can be used for deducing the Layout Tag of A, C, D Matrices +template +constexpr +auto +stride_to_layout_tag_A() { + if constexpr (is_major<0, StrideA>()) { // M major + return layout::ColumnMajor{}; + } + else { // K major + return layout::RowMajor{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +constexpr +auto +stride_to_layout_tag_B() { + if constexpr (is_major<0, StrideB>()) { // N major + return layout::RowMajor{}; + } + else { // K major + return layout::ColumnMajor{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +constexpr +auto +stride_to_layout_tag_C() { + if constexpr (is_major<0, StrideC>()) { // M major + return layout::ColumnMajor{}; + } + else { // N major + return layout::RowMajor{}; + } + + CUTE_GCC_UNREACHABLE; +} + +// Utilities to map Stride back on to their corresponding layout tags +template +struct StrideToLayoutTagA { + using type = decltype(detail::stride_to_layout_tag_A()); +}; + +template +struct StrideToLayoutTagB { + using type = decltype(detail::stride_to_layout_tag_B()); +}; + +template +struct StrideToLayoutTagC { + using type = decltype(detail::stride_to_layout_tag_C()); +}; + +// Convenience aliases +template +using StrideToLayoutTagA_t = typename StrideToLayoutTagA::type; + +template +using StrideToLayoutTagB_t = typename StrideToLayoutTagB::type; + +template +using StrideToLayoutTagC_t = typename StrideToLayoutTagC::type; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Inspects a tiled copy and whether its copy engine is TMA or not +template +constexpr bool is_tma_copy_engine() { + if constexpr (cute::is_void_v) { + return false; + } + else { + if constexpr ( cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + ) { + return true; + } + } + return false; +} + +// Inspects a TiledCopy and returns its alignment in terms of element count +template +constexpr int +get_alignment_count_from_gmem_tiled_copy() { + if constexpr (cute::is_void_v) { + return 1; + } + + // Account for ElementC = void kernels + else if constexpr (cute::is_void_v) { + return 0; + } + + else { + // For TMA tiled copies, we know the alignment has to be 128 bits + if constexpr (is_tma_copy_engine()) { + return 128 / sizeof_bits::value; + } + else { + // For non-TMA tiled copies, TiledCopy holds the alignment count directly in its TiledShape_MN + return GmemTiledCopy::NumValSrc; + } + } +} + +// Return the shape that is associated with stride-1 mode, or 1 if not found +template +CUTLASS_HOST_DEVICE constexpr +auto +get_contiguous_shape(Shape const & shape, Stride const & stride) { + using namespace cute; + auto idx = find_if(append(flatten(stride), _1{}), [](auto s){ return is_constant<1,decltype(s)>{}; }); + return get(append(flatten(shape), _1{})); +} + +// Check if tensor shape satisfies a given major alignment +template +CUTLASS_HOST_DEVICE constexpr +bool +check_alignment(Shape const & shape, Stride const & stride) { + return is_major<0>(stride) + ? get_contiguous_shape(cute::get<0>(shape), cute::get<0>(stride)) % Alignment == 0 + : get_contiguous_shape(cute::get<1>(shape), cute::get<1>(stride)) % Alignment == 0; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/include/cutlass/device_kernel.h b/include/cutlass/device_kernel.h index cde9f1ff6c..c019dfecd1 100644 --- a/include/cutlass/device_kernel.h +++ b/include/cutlass/device_kernel.h @@ -64,7 +64,6 @@ __global__ void Kernel(typename Operator::Params params) { // Dynamic shared memory base pointer extern __shared__ int SharedStorageBase[]; - // Declare pointer to dynamic shared memory. typename Operator::SharedStorage *shared_storage = reinterpret_cast(SharedStorageBase); @@ -81,13 +80,11 @@ __global__ void Kernel2(typename Operator::Params params) { // Dynamic shared memory base pointer extern __shared__ int SharedStorageBase[]; - // Declare pointer to dynamic shared memory. typename Operator::SharedStorage *shared_storage = reinterpret_cast(SharedStorageBase); Operator::invoke(params, *shared_storage); - } @@ -108,7 +105,6 @@ void device_kernel(CUTLASS_GRID_CONSTANT typename Operator::Params const params) { // Dynamic shared memory base pointer extern __shared__ char smem[]; - Operator op; op(params, smem); } diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index 045b05f38e..e94a025a94 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -35,13 +35,15 @@ #include "cute/atom/copy_traits_sm90.hpp" #include "cutlass/detail/dependent_false.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/gemm/collective/builders/sm90_common.inl" #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/collective/collective_epilogue.hpp" #include "cutlass/epilogue/thread/linear_combination.h" #include "cutlass/epilogue/thread/linear_combination_generic.h" #include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" #if defined(__CUDACC_RTC__) #include @@ -57,6 +59,21 @@ namespace cutlass::epilogue::collective { namespace detail { +// Returns the parameterized dispatch policy for the TMA epilogue +template +constexpr auto +sm90_get_tma_dispatch_policy() { + using namespace cute; + + constexpr int EpiTiles = size(shape_div(take<0,2>(TileShapeMNK{}), EpilogueTileMN{})); + constexpr int FragmentSize = size(EpilogueTileMN{}) / (detail::sm90_is_cooperative_v ? 256 : 128); + constexpr int ReuseSmemC = sizeof_bits_v == sizeof_bits_v; + constexpr int StagesD = 2; + constexpr int StagesC = ReuseSmemC ? cute::max(EpiTiles, StagesD + 1) : EpiTiles; + + return Sm90TmaWarpSpecialized{}; +} + // Returns the smem layout atom to be used for C or D matrix template constexpr auto @@ -64,13 +81,13 @@ sm90_get_epilogue_smem_swizzle_layout_atom() { using namespace cute; // ColMajor C/D (M-major) - if constexpr (size<0>(GmemStrideType{}) == 1) { + if constexpr (cutlass::gemm::detail::is_major<0>(GmemStrideType{})) { return cutlass::gemm::collective::detail::ss_smem_selector< cute::GMMA::Major::MN, Element, decltype(get<0>(EpilogueTile_MN{})), decltype(get<1>(EpilogueTile_MN{})) >(); } // RowMajor C/D (N-major) - else if constexpr (size<1>(GmemStrideType{}) == 1) { + else if constexpr (cutlass::gemm::detail::is_major<1>(GmemStrideType{})) { return cutlass::gemm::collective::detail::ss_smem_selector< cute::GMMA::Major::K , Element, decltype(get<0>(EpilogueTile_MN{})), decltype(get<1>(EpilogueTile_MN{})) >(); @@ -87,7 +104,7 @@ sm90_compute_tile_shape_or_override() { if constexpr (cute::is_same_v) { if constexpr (detail::sm90_is_cooperative_v) { - return Shape<_128,_16>{}; + return Shape<_128,_32>{}; } else if constexpr (detail::sm90_is_warp_specialized_v) { return Shape<_64,_32>{}; @@ -104,7 +121,7 @@ sm90_compute_tile_shape_or_override() { static_assert(!is_layout::value, "EpilogueTile must be a cute::Tile or cute::Shape"); static_assert(M == 64 && detail::sm90_is_warp_specialized_v || M == 128 && detail::sm90_is_cooperative_v, "Unsupported tile shape"); - static_assert(N % 8 == 0, "Unsupported tile shape"); + static_assert(N % 16 == 0, "Unsupported tile shape"); return epi_tile; } @@ -152,12 +169,43 @@ sm90_get_smem_load_op_for_source() { } } +// callbacks builder with TMA aux out +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator +> +struct CallbacksBuilder< + Sm90TmaWarpSpecialized, + FusionOp, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + enable_if_t +> { + using GmemStrideTypeAux = gemm::TagToStrideC_t; + using SmemLayoutAtomAux = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom< + GmemStrideTypeAux, typename FusionOp::ElementAux, EpilogueTile_MN>()); + using SmemCopyOpAux = decltype(detail::sm90_get_smem_store_op_for_accumulator< + GmemStrideTypeAux, typename FusionOp::ElementAux>()); + + using Callbacks = fusion::FusionCallbacks< + Sm90TmaWarpSpecialized, + FusionOp, TileShape_MNK, EpilogueTile_MN, + SmemLayoutAtomAux, SmemCopyOpAux + >; +}; + // Helper for building TMA warp-specialized collective epilogues, specialized by -// the thread-level epilogue operation performed and the dispatch policy to use. +// the fusion operation performed and the dispatch policy to use. template < class TileShape_MNK, - class ClusterShape_MNK, - class EpilogueTileType, + class EpilogueTile_MN, class ElementAccumulator, class ElementCompute, class ElementC_, @@ -166,21 +214,27 @@ template < class ElementD, class GmemLayoutTagD, int AlignmentD, - class Schedule, - class ThreadOp, + class FusionOpOrCallbacks, class DispatchPolicy > -struct TmaBuilderImpl { - - // Passing void C disables source load +struct Sm90TmaBuilderImpl { + // Passing void C disables source load + smem allocation using ElementC = cute::conditional_t,ElementD,ElementC_>; // prevents void ref breakages using GmemLayoutTagC = cute::conditional_t,GmemLayoutTagD,GmemLayoutTagC_>; - using GmemStrideTypeC = gemm::TagToStrideC_t; - using GmemStrideTypeD = gemm::TagToStrideC_t; + using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; + using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; - using EpilogueTile_MN = decltype(detail::sm90_compute_tile_shape_or_override< - ElementD, EpilogueTileType, Schedule>()); + // TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks + // instance or a direct visitor implementation, e.g. fusion::Sm90LinearCombination + using FusionCallbacks = + typename CallbacksBuilder< + DispatchPolicy, + FusionOpOrCallbacks, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator + >::Callbacks; using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue< DispatchPolicy, @@ -190,9 +244,9 @@ struct TmaBuilderImpl { GmemStrideTypeC, ElementD, GmemStrideTypeD, - ThreadOp, + FusionCallbacks, SM90_TMA_LOAD, - decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), decltype(detail::sm90_get_smem_load_op_for_source()), SM90_TMA_STORE, decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), @@ -217,7 +271,7 @@ template < class ElementD, class GmemLayoutTagD, int AlignmentD, - class Schedule + FloatRoundStyle RoundStyle > struct CollectiveBuilder< arch::Sm90, @@ -233,8 +287,9 @@ struct CollectiveBuilder< ElementD, GmemLayoutTagD, AlignmentD, - Schedule, - cute::enable_if_t>> { + NoSmemWarpSpecialized, + fusion::LinearCombination, + void> { // Passing void C disables source load using ElementC = cute::conditional_t, @@ -247,12 +302,12 @@ struct CollectiveBuilder< static constexpr int FragmentSize = 1; using ThreadOp = thread::LinearCombination< ElementD, FragmentSize, ElementAccumulator, ElementCompute, - ScaleType, FloatRoundStyle::round_to_nearest, ElementC>; + ScaleType, RoundStyle, ElementC>; using CollectiveOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, ThreadOp, cutlass::gemm::EpilogueDefault> >; @@ -265,13 +320,14 @@ template < class EpilogueTileType, class ElementAccumulator, class ElementCompute, - class ElementC_, + class ElementC, class GmemLayoutTagC, int AlignmentC, class ElementD, class GmemLayoutTagD, int AlignmentD, - class Schedule + class Schedule, + class FusionOperation > struct CollectiveBuilder< arch::Sm90, @@ -281,36 +337,38 @@ struct CollectiveBuilder< EpilogueTileType, ElementAccumulator, ElementCompute, - ElementC_, + ElementC, GmemLayoutTagC, AlignmentC, ElementD, GmemLayoutTagD, AlignmentD, Schedule, + FusionOperation, cute::enable_if_t || cute::is_same_v >> { -public: - using ElementC = cute::conditional_t,ElementD,ElementC_>; // prevents void ref breakages - static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v ? - thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; - - static constexpr int FragmentSize = 4; - using ThreadOp = thread::LinearCombination< - ElementD, FragmentSize, ElementAccumulator, ElementCompute, - ScaleType, FloatRoundStyle::round_to_nearest, ElementC>; - private: - static constexpr int StagesC = 1; - static constexpr int StagesD = 2; - static constexpr bool DisableReuseSmemC = true; - using Impl = detail::TmaBuilderImpl< - TileShape_MNK, ClusterShape_MNK, EpilogueTileType, ElementAccumulator, ElementCompute, - ElementC_, GmemLayoutTagC, AlignmentC, ElementD, GmemLayoutTagD, AlignmentD, - Schedule, ThreadOp, cutlass::epilogue::Sm90TmaWarpSpecialized>; + using EpilogueTile_MN = + decltype(detail::sm90_compute_tile_shape_or_override()); + using DispatchPolicy = + decltype(detail::sm90_get_tma_dispatch_policy()); public: - using CollectiveOp = typename Impl::CollectiveOp; + using CollectiveOp = + typename detail::Sm90TmaBuilderImpl< + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + FusionOperation, + DispatchPolicy + >::CollectiveOp; }; // Auto builder @@ -326,7 +384,7 @@ template < class ElementD, class GmemLayoutTagD, int AlignmentD, - class Schedule + class FusionOperation > struct CollectiveBuilder< arch::Sm90, @@ -342,18 +400,13 @@ struct CollectiveBuilder< ElementD, GmemLayoutTagD, AlignmentD, - Schedule, - cute::enable_if_t>> { - + EpilogueScheduleAuto, + FusionOperation, + void> { private: - static constexpr bool IsTmaAligned = cutlass::gemm::collective::detail::is_aligned< - ElementC, AlignmentC, ElementD, AlignmentD, cutlass::gemm::collective::detail::tma_alignment_bytes>(); - - // Current TMA epilogues require sixteen-bit data types and epilogue tile M to be of size 64. - // Only dispatch to the TMA builder if these requirements are satisfied. - static constexpr bool IsSixteenBit = sizeof_bits::value == 16 && sizeof_bits::value == 16; - static constexpr bool IsEpiTileM64 = size<0>(shape(TileShape_MNK{})) == 64; - + // Pick No-Smem epilogue as the Auto Epilogue Schedule (Auto schedules do not guarantee best performance) + // since TMA epilogues are not compatible with non-TMA non-WS mainloops + using EpilogueSchedule = NoSmemWarpSpecialized; using _CollectiveBuilder = CollectiveBuilder< arch::Sm90, arch::OpClassTensorOp, @@ -368,16 +421,14 @@ private: ElementD, GmemLayoutTagD, AlignmentD, - cute::conditional_t + EpilogueSchedule >; public: - using ThreadOp = typename _CollectiveBuilder::ThreadOp; using CollectiveOp = typename _CollectiveBuilder::CollectiveOp; }; -// Tma warp-specialized builder for elementwise fusion +// DEPRECATED Tma warp-specialized builder for elementwise fusion template < class TileShape_MNK, class ClusterShape_MNK, @@ -390,9 +441,11 @@ template < class ElementD, class GmemLayoutTagD, int AlignmentD, - class Schedule + class Schedule, + class UnusedFusionOp > -struct CollectiveBuilder< +struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombEltAct instead")]] +CollectiveBuilder< arch::Sm90, arch::OpClassTensorOp, TileShape_MNK, @@ -407,30 +460,38 @@ struct CollectiveBuilder< GmemLayoutTagD, AlignmentD, Schedule, + UnusedFusionOp, cute::enable_if_t || cute::is_base_of_v >> { -public: - static constexpr int FragmentSize = 4; - using ThreadOp = thread::LinearCombinationGeneric< - Schedule::ActivationFunctor, - ElementD, FragmentSize, - ElementAccumulator, ElementCompute, Schedule::Scale, - Schedule::Round>; - private: - static constexpr int StagesC = 1; - static constexpr int StagesD = 2; - static constexpr bool DisableReuseSmemC = true; - using Impl = detail::TmaBuilderImpl< - TileShape_MNK, ClusterShape_MNK, EpilogueTileType, ElementAccumulator, ElementCompute, - ElementC, GmemLayoutTagC, AlignmentC, ElementD, GmemLayoutTagD, AlignmentD, - Schedule, ThreadOp, cutlass::epilogue::Sm90TmaWarpSpecialized>; + using FusionOp = + fusion::LinCombEltAct; + using ImplSchedule = + cute::conditional_t, + TmaWarpSpecialized, TmaWarpSpecializedCooperative>; public: - using CollectiveOp = typename Impl::CollectiveOp; + using CollectiveOp = + typename CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + ImplSchedule, + FusionOp + >::CollectiveOp; }; -// Tma warp-specialized builder for bias + elementwise fusion +// DEPRECATED Tma warp-specialized builder for bias + elementwise fusion template < class TileShape_MNK, class ClusterShape_MNK, @@ -438,14 +499,16 @@ template < class ElementAccumulator, class ElementCompute, class ElementC_, - class GmemLayoutTagC, + class GmemLayoutTagC_, int AlignmentC, class ElementD, class GmemLayoutTagD, int AlignmentD, - class Schedule + class Schedule, + class UnusedFusionOp > -struct CollectiveBuilder< +struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombPerRowBiasEltAct or fusion::LinCombPerRowBiasEltActAux instead")]] +CollectiveBuilder< arch::Sm90, arch::OpClassTensorOp, TileShape_MNK, @@ -454,35 +517,75 @@ struct CollectiveBuilder< ElementAccumulator, ElementCompute, ElementC_, - GmemLayoutTagC, + GmemLayoutTagC_, AlignmentC, ElementD, GmemLayoutTagD, AlignmentD, Schedule, + UnusedFusionOp, cute::enable_if_t || cute::is_base_of_v >> { private: - // Passing void C disables source load - using ElementC = cute::conditional_t, ElementD, ElementC_>; // prevents void ref breakages + using EpilogueTile_MN = decltype(detail::sm90_compute_tile_shape_or_override< + ElementD, EpilogueTileType, Schedule>()); + // MSVC doesn't seem to be able to deduce DispatchPolicy correctly if it's + // defined as decltype of a detail::sm90_get_tma_dispatch_policy call. + // Instead, we paste in the contents of that function. A natural refactoring + // would be to create a type alias in the detail namespace. + using DispatchPolicy = Sm90TmaWarpSpecialized< + /* StagesC = */ size(shape_div(take<0, 2>(TileShape_MNK{}), EpilogueTile_MN{})), + /* StagesD = */ 2, + /* FragmentSize = */ size(EpilogueTile_MN{}) / (detail::sm90_is_cooperative_v ? 256 : 128), + /* ReuseSmemC = */ sizeof_bits_v == sizeof_bits_v + >; -public: - static constexpr int FragmentSize = 4; - using ThreadOp = thread::LinearCombinationBiasElementwise< - ElementC, ElementAccumulator, ElementCompute, ElementD, typename Schedule::ElementT, FragmentSize, - typename Schedule::ActivationFunctor, typename Schedule::BiasOp, - Schedule::StoreT, typename Schedule::ElementBias>; + using GmemStrideTypeAux = gemm::TagToStrideC_t; + using SmemLayoutAtomAux = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom< + GmemStrideTypeAux, typename Schedule::ElementT, EpilogueTile_MN>()); + using SmemCopyOpAux = decltype(detail::sm90_get_smem_store_op_for_accumulator< + GmemStrideTypeAux, typename Schedule::ElementT>()); + using FusionOperationAux = fusion::LinCombPerRowBiasEltActAux< + GmemLayoutTagD, Schedule::ActivationFunctor, ElementD, ElementCompute, + typename Schedule::ElementT, typename Schedule::ElementBias, ElementCompute + >; + using FusionCallbacksAux = fusion::FusionCallbacks< + DispatchPolicy, FusionOperationAux, TileShape_MNK, EpilogueTile_MN, SmemLayoutAtomAux, SmemCopyOpAux + >; -private: - static constexpr int StagesC = 1; - static constexpr int StagesD = 2; - using Impl = detail::TmaBuilderImpl< - TileShape_MNK, ClusterShape_MNK, EpilogueTileType, ElementAccumulator, ElementCompute, - ElementC_, GmemLayoutTagC, AlignmentC, ElementD, GmemLayoutTagD, AlignmentD, - Schedule, ThreadOp, cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise>; + using FusionOperationNoAux = fusion::LinCombPerRowBiasEltAct< + Schedule::ActivationFunctor, ElementD, ElementCompute, + typename Schedule::ElementBias, ElementCompute + >; + using FusionCallbacksNoAux = fusion::FusionCallbacks< + DispatchPolicy, FusionOperationNoAux, TileShape_MNK, EpilogueTile_MN + >; + + using ElementC = cute::conditional_t,ElementD,ElementC_>; // prevents void ref breakages + using GmemLayoutTagC = cute::conditional_t,GmemLayoutTagD,GmemLayoutTagC_>; + + using GmemStrideTypeC = gemm::TagToStrideC_t; + using GmemStrideTypeD = gemm::TagToStrideC_t; public: - using CollectiveOp = typename Impl::CollectiveOp; + using CollectiveOp = cutlass::epilogue::collective::Sm90EpilogueTmaWarpSpecializedBiasElementwise< + DispatchPolicy::StagesC, + DispatchPolicy::StagesD, + DispatchPolicy::FragmentSize, + TileShape_MNK, + EpilogueTile_MN, + ElementC_, // Need to pass void through to expose via GemmUniversal + GmemStrideTypeC, + ElementD, + GmemStrideTypeD, + cute::conditional_t, + SM90_TMA_LOAD, + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_load_op_for_source()), + SM90_TMA_STORE, + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_store_op_for_accumulator()) + >; }; // CollectiveBuilder that transposed epilogue below is used for sm90 gmma RS TT kernels @@ -500,7 +603,7 @@ template < class ElementD, class GmemLayoutTagD, int AlignmentD, - class Schedule + FloatRoundStyle RoundStyle > struct CollectiveBuilder< arch::Sm90, @@ -516,8 +619,9 @@ struct CollectiveBuilder< ElementD, GmemLayoutTagD, AlignmentD, - Schedule, - cute::enable_if_t>> { + cutlass::gemm::EpilogueTransposed, + fusion::LinearCombination, + void> { // Passing void C disables source load using ElementC = cute::conditional_t, ElementD, ElementC_>; // prevents cute breakages @@ -529,12 +633,12 @@ struct CollectiveBuilder< static constexpr int FragmentSize = 1; using ThreadOp = thread::LinearCombination< ElementD, FragmentSize, ElementAccumulator, ElementCompute, - ScaleType, FloatRoundStyle::round_to_nearest, ElementC>; + ScaleType, RoundStyle, ElementC>; using CollectiveOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< cutlass::epilogue::collective::DefaultEpilogue< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, ThreadOp, cutlass::gemm::EpilogueTransposed> >; diff --git a/include/cutlass/epilogue/collective/collective_builder.hpp b/include/cutlass/epilogue/collective/collective_builder.hpp index d71b7a30e1..46ad166b2e 100644 --- a/include/cutlass/epilogue/collective/collective_builder.hpp +++ b/include/cutlass/epilogue/collective/collective_builder.hpp @@ -31,6 +31,7 @@ #pragma once #include "cutlass/detail/dependent_false.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -60,6 +61,7 @@ template < class GmemLayoutTagD, int AlignmentD, class Schedule, + class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination, class Enable = void > struct CollectiveBuilder { @@ -67,6 +69,43 @@ struct CollectiveBuilder { "Could not build a collective epilogue for given parameters."); }; +// helper sub-builder for epilogue fusion callbacks (for internal use by CollectiveBuilder only) +namespace detail { + +// callbacks builder with operation tag +template< + class DispatchPolicy, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator, + class = void +> +struct CallbacksBuilder { + using Callbacks = fusion::FusionCallbacks; +}; + +// callbacks builder with callbacks passthrough +template < + class DispatchPolicy, + class FusionCallbacks, + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator +> +struct CallbacksBuilder< + DispatchPolicy, + FusionCallbacks, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + enable_if_t> +> { + using Callbacks = FusionCallbacks; +}; + +} // namespace detail + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::epilogue::collective diff --git a/include/cutlass/epilogue/collective/collective_epilogue.hpp b/include/cutlass/epilogue/collective/collective_epilogue.hpp index 37bb79b032..36ccdce0be 100644 --- a/include/cutlass/epilogue/collective/collective_epilogue.hpp +++ b/include/cutlass/epilogue/collective/collective_epilogue.hpp @@ -1,24 +1,30 @@ /*************************************************************************************************** - * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ diff --git a/include/cutlass/epilogue/collective/default_epilogue.hpp b/include/cutlass/epilogue/collective/default_epilogue.hpp index dcb47d9742..aea8721d66 100644 --- a/include/cutlass/epilogue/collective/default_epilogue.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -86,7 +86,7 @@ class DefaultEpilogue { struct SharedStorage { }; - // Host side epilgoue arguments + // Host side epilogue arguments struct Arguments { typename ThreadEpilogueOp::Params thread{}; ElementC const* ptr_C = nullptr; @@ -111,6 +111,14 @@ class DefaultEpilogue { return args; } + template + CUTLASS_HOST_DEVICE static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + CUTLASS_HOST_DEVICE DefaultEpilogue(Params const& params_) : params(params_), epilogue_op(params_.thread) { } diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index 033f5ccc5f..af77479c77 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -33,11 +33,13 @@ #include "cutlass/cutlass.h" #include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/epilogue/dispatch_policy.hpp" #include "cute/tensor.hpp" #include "cute/numeric/int.hpp" +#include "cute/util/type_traits.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -47,16 +49,32 @@ namespace collective { namespace detail { +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +constexpr bool +is_m_major() { + return cutlass::gemm::detail::is_major<0,Stride>(); +} + +template +constexpr bool +is_n_major() { + return cutlass::gemm::detail::is_major<1,Stride>(); +} + +using cutlass::atomic_maximum; + template static constexpr int elements_per_access_v = cutlass::sizeof_bits::value / cutlass::sizeof_bits::value; template static constexpr bool sm90_is_cooperative_v = - std::is_base_of_v; + cute::is_base_of_v; template static constexpr bool sm90_is_warp_specialized_v = - std::is_base_of_v; + cute::is_base_of_v; template struct EmptyStorage { @@ -87,25 +105,19 @@ struct IsThreadEpilogueOpWithBias ::value will be true only if: -// class T has member CopyOpS2G and T::CopyOpS2G is true -template -struct IF_EPILOGUE_USES_TMA { static constexpr bool value = false; }; - -template -struct IF_EPILOGUE_USES_TMA > -{ static constexpr bool value = true; }; - // Wrapper class to use operator-style epilogues in sm90 TMA warp-specialized kernels template class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { public: - using LoadPipeline = cutlass::PipelineTransactionAsync<0>; // 0 stage to disable smem alloc + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + using LoadPipeline = cutlass::PipelineTransactionAsync<0>; using LoadPipelineState = cutlass::PipelineState<0>; constexpr static uint32_t TmaTransactionBytes = 0; - using StorePipeline = cutlass::PipelineTmaStore<1>; // tma store pipe has no smem alloc - using StorePipelineState = cutlass::PipelineState<1>; + using StorePipeline = cutlass::PipelineTmaStore<0>; + using StorePipelineState = cutlass::PipelineState<0>; using TensorStorage = typename EpilogueOp::SharedStorage; using PipelineStorage = typename LoadPipeline::SharedStorage; @@ -114,34 +126,45 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { CUTLASS_HOST_DEVICE static constexpr int get_load_pipe_increment([[maybe_unused]] TileShapeMNK) { - return 1; + return 0; } template CUTLASS_HOST_DEVICE static constexpr int get_store_pipe_increment([[maybe_unused]] TileShapeMNK) { - return 1; + return 0; } CUTLASS_DEVICE - static void prefetch_tma_descriptors([[maybe_unused]] typename EpilogueOp::Params const&) - { + static void prefetch_tma_descriptors([[maybe_unused]] typename EpilogueOp::Params const&) { } // ctor inheritance using EpilogueOp::EpilogueOp; + CUTLASS_HOST_DEVICE + Sm90TmaWarpSpecializedAdapter( + typename EpilogueOp::Params const& params, + [[maybe_unused]] TensorStorage& shared_tensors) + : EpilogueOp(params) { } + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return false; + } + template< class ProblemShapeMNKL, class TileShapeMNK, class TileCoordMNKL, class TiledMma > - CUTLASS_DEVICE void + CUTLASS_DEVICE auto load( [[maybe_unused]] LoadPipeline load_pipeline, - [[maybe_unused]] LoadPipelineState load_pipe_producer_state, + LoadPipelineState load_pipe_producer_state, [[maybe_unused]] ProblemShapeMNKL problem_shape_mnkl, [[maybe_unused]] TileShapeMNK tile_shape_MNK, [[maybe_unused]] TileCoordMNKL tile_coord_mnkl, @@ -149,14 +172,15 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { [[maybe_unused]] int thread_idx, [[maybe_unused]] TensorStorage& shared_tensors) { - // source load is performed in epilogue operator + return load_pipe_producer_state; } - CUTLASS_DEVICE void + CUTLASS_DEVICE auto load_tail( [[maybe_unused]] LoadPipeline load_pipeline, - [[maybe_unused]] LoadPipelineState load_pipe_producer_state) + LoadPipelineState load_pipe_producer_state) { + return load_pipe_producer_state; } template< @@ -166,12 +190,12 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { class AccEngine, class AccLayout, class TiledMma > - CUTLASS_DEVICE void + CUTLASS_DEVICE auto store( [[maybe_unused]] LoadPipeline load_pipeline, - [[maybe_unused]] LoadPipelineState load_pipe_consumer_state, + LoadPipelineState load_pipe_consumer_state, [[maybe_unused]] StorePipeline store_pipeline, - [[maybe_unused]] StorePipelineState store_pipe_producer_state, + StorePipelineState store_pipe_producer_state, ProblemShapeMNKL problem_shape_mnkl, TileShapeMNK tile_shape_MNK, TileCoordMNKL tile_coord_mnkl, @@ -201,6 +225,17 @@ class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { residue_mnk, thread_idx, reinterpret_cast(&shared_tensors)); + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + CUTLASS_DEVICE auto + store_tail( + [[maybe_unused]] LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + [[maybe_unused]] StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state) { + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); } }; diff --git a/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp b/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp index 512ce1a841..70edf77d5f 100644 --- a/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp +++ b/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp @@ -130,6 +130,14 @@ class EpilogueTensorBroadcast { return args; } + template + CUTLASS_HOST_DEVICE static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + CUTLASS_HOST_DEVICE EpilogueTensorBroadcast(Params const& params_) : params(params_), epilogue_op(params_.thread) { } diff --git a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp index 509f3b94d1..0374a1036b 100644 --- a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp +++ b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp @@ -120,6 +120,14 @@ class Epilogue { return args; } + template + CUTLASS_HOST_DEVICE static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + CUTLASS_HOST_DEVICE Epilogue(Params const& params_) : params(params_), epilogue_op(params_.thread) { } diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index a52bb2b8a7..5bdfab882f 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -1,24 +1,30 @@ /*************************************************************************************************** - * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ @@ -33,6 +39,8 @@ #include "cutlass/epilogue/dispatch_policy.hpp" #include "cutlass/epilogue/collective/detail.hpp" #include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/trace.h" #include "cute/tensor.hpp" @@ -47,14 +55,15 @@ namespace collective { template < int StagesC_, int StagesD_, - bool DisableSmemReuseC_, - class BlockTileShape_, // (BLK_M,BLK_N,BLK_K) - class EpilogueTileShape_, // (EPI_TILE_M,EPI_TILE_N) + int FragmentSize_, + bool ReuseSmemC_, + class CtaTileMNK_, // (CTA_M,CTA_N,CTA_K) + class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) class ElementC_, class StrideC_, class ElementD_, class StrideD_, - class ThreadEpilogueOp_, + class FusionCallbacks_, class CopyOpG2S_, class SmemLayoutAtomC_, class CopyOpS2R_, @@ -63,14 +72,14 @@ template < class CopyOpR2S_ > class CollectiveEpilogue< - Sm90TmaWarpSpecialized, - BlockTileShape_, - EpilogueTileShape_, + Sm90TmaWarpSpecialized, + CtaTileMNK_, + EpilogueTile_, ElementC_, StrideC_, ElementD_, StrideD_, - ThreadEpilogueOp_, + FusionCallbacks_, CopyOpG2S_, SmemLayoutAtomC_, CopyOpS2R_, @@ -82,20 +91,14 @@ class CollectiveEpilogue< // // Type Aliases // - using DispatchPolicy = Sm90TmaWarpSpecialized; - using BlockTileShape = BlockTileShape_; - using EpilogueTileShape = EpilogueTileShape_; - using ThreadEpilogueOp = ThreadEpilogueOp_; - using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; - using ElementCompute = typename ThreadEpilogueOp::ElementCompute; - using ElementScalar = ElementCompute; - using ElementBias = typename detail::IsThreadEpilogueOpWithBias::type; - using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using DispatchPolicy = Sm90TmaWarpSpecialized; + using CtaTileMNK = CtaTileMNK_; + using EpilogueTile = EpilogueTile_; + using FusionCallbacks = FusionCallbacks_; using ElementC = ElementC_; using StrideC = StrideC_; using ElementD = ElementD_; using StrideD = StrideD_; - using CopyOpG2S = CopyOpG2S_; using SmemLayoutAtomC = SmemLayoutAtomC_; using CopyOpS2R = CopyOpS2R_; @@ -103,55 +106,58 @@ class CollectiveEpilogue< using SmemLayoutAtomD = SmemLayoutAtomD_; using CopyOpR2S = CopyOpR2S_; + using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; using GmemTiledCopyC = SM90_TMA_LOAD; using GmemTiledCopyD = SM90_TMA_STORE; - constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount; - constexpr static bool iskThreadEpilogueOpWithBias = detail::IsThreadEpilogueOpWithBias::value; - using AlignmentType = typename uint_bit::value * kOutputAlignment>::type; - - static_assert(!is_layout::value && is_tuple::value, "EpilogueTileShape must be a cute::Shape"); - static_assert(rank(BlockTileShape{}) == 3, "BlockTileShape must be rank-3: [BLK_M,BLK_N,BLK_K]"); - static_assert(rank(EpilogueTileShape{}) == 2, "EpilogueTileShape must be rank-2: [EPI_TILE_M,EPI_TILE_N]"); - static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M"); + static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N"); + static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); private: using InternalElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages constexpr static int StagesC = StagesC_; constexpr static int StagesD = StagesD_; - constexpr static bool is_source_supported = ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default || - ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::NoBetaScaling; - static_assert((cute::is_void_v && not is_source_supported) || (not cute::is_void_v && is_source_supported), - "Inconsistent C type and Scale kind"); + constexpr static bool is_source_supported = not cute::is_void_v; // internal optimization to reuse C shared memory for storing D using SmemLayoutAtomBitsC = decltype(downcast::value>(SmemLayoutAtomC{})); using SmemLayoutAtomBitsD = decltype(downcast::value>(SmemLayoutAtomD{})); - constexpr static bool ReuseSmemC = not DispatchPolicy::DisableSmemReuseC && - is_source_supported && - sizeof(InternalElementC) == sizeof(ElementD) && - StrideC{} == StrideD{} && - cute::is_same_v; + constexpr static bool support_smem_reuse = is_source_supported && + sizeof(InternalElementC) == sizeof(ElementD) && + StrideC{} == StrideD{} && + StagesD <= StagesC && + cute::is_same_v; + constexpr static bool ReuseSmemC = DispatchPolicy::ReuseSmemC; + static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = detail::is_m_major(); public: using SmemLayoutC = decltype(tile_to_shape( SmemLayoutAtomC{}, - make_shape(size<0>(BlockTileShape{}), size<1>(BlockTileShape{}), Int{}), - cute::conditional_t(StrideC{}) == 1, Step<_2,_1,_3>, Step<_1,_2,_3>>{} )); + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); using SmemLayoutD = decltype(tile_to_shape( SmemLayoutAtomD{}, - make_shape(size<0>(EpilogueTileShape{}), size<1>(EpilogueTileShape{}), Int{}), - cute::conditional_t(StrideD{}) == 1, Step<_2,_1,_3>, Step<_1,_2,_3>>{} )); + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); // TMA pipeline for loading C - using LoadPipeline = cutlass::PipelineTransactionAsync; - using LoadPipelineState = cutlass::PipelineState; + using LoadPipeline = cutlass::PipelineTransactionAsync; + using LoadPipelineState = cutlass::PipelineState; constexpr static uint32_t TmaTransactionBytes = size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof(InternalElementC)); // TMA pipeline for storing D - using StorePipeline = cutlass::PipelineTmaStore; + using StorePipeline = cute::conditional_t, + cutlass::PipelineTmaStore>; using StorePipelineState = cutlass::PipelineState; struct SharedStorage { @@ -162,6 +168,9 @@ class CollectiveEpilogue< alignas(128) cute::conditional_t, array_aligned> smem_D; + + using FusionStorage = typename FusionCallbacks::SharedStorage; + alignas(128) FusionStorage thread; } tensors; using PipelineStorage = typename LoadPipeline::SharedStorage; @@ -172,14 +181,14 @@ class CollectiveEpilogue< // Host side epilogue arguments struct Arguments { - typename ThreadEpilogueOp::Params thread; + typename FusionCallbacks::Arguments thread{}; ElementC const* ptr_C; StrideC dC; ElementD const* ptr_D; StrideD dD; }; - // Device side epilgoue params + // Device side epilogue params struct Params { using TMA_C = decltype(make_tma_copy( CopyOpG2S{}, @@ -192,7 +201,7 @@ class CollectiveEpilogue< repeat_like(StrideD{}, int32_t(0)), StrideD{}), SmemLayoutD{}(_,_,0))); - typename ThreadEpilogueOp::Params thread{}; + typename FusionCallbacks::Params thread{}; TMA_C tma_load_c; TMA_D tma_store_d; }; @@ -207,24 +216,15 @@ class CollectiveEpilogue< ProblemShape const& problem_shape, Arguments const& args, [[maybe_unused]] void* workspace) { - // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(problem_shape, Int<1>{}); - auto M = get<0>(problem_shape_MNKL); - auto N = get<1>(problem_shape_MNKL); - auto L = get<3>(problem_shape_MNKL); - - typename Params::TMA_C tma_load_c = [&]() { - if constexpr (not cute::is_void_v) { - Tensor tensor_c = make_tensor(static_cast(args.ptr_C), make_layout(make_shape(M,N,L), args.dC)); - return make_tma_copy( - CopyOpG2S{}, - tensor_c, - SmemLayoutC{}(_,_,0)); - } - else { - return typename Params::TMA_C{}; - } - }(); + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + typename Params::TMA_C tma_load_c; + if constexpr (not cute::is_void_v) { + Tensor tensor_c = make_tensor(args.ptr_C, make_layout(make_shape(M,N,L), args.dC)); + tma_load_c = make_tma_copy(CopyOpG2S{}, tensor_c, SmemLayoutC{}(_,_,0)); + } Tensor tensor_d = make_tensor(args.ptr_D, make_layout(make_shape(M,N,L), args.dD)); typename Params::TMA_D tma_store_d = make_tma_copy( @@ -233,19 +233,43 @@ class CollectiveEpilogue< SmemLayoutD{}(_,_,0)); return { - args.thread, + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), tma_load_c, tma_store_d }; } + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int min_tma_aligned_elements_D = tma_alignment_bits / cutlass::sizeof_bits::value; + bool implementable = cutlass::detail::check_alignment(cute::make_shape(M,N,L), StrideD{}); + + if constexpr (not cute::is_void_v) { + constexpr int min_tma_aligned_elements_C = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), StrideC{}); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + return implementable; + } + template CUTLASS_HOST_DEVICE static constexpr int get_load_pipe_increment(TileShapeMNK tile_shape_MNK) { - // Compute number of C subtiles (currently always one) - constexpr int epi_m = size<0>(tile_shape_MNK) / size<0>(SmemLayoutC{}); - constexpr int epi_n = size<1>(tile_shape_MNK) / size<1>(SmemLayoutC{}); + // Compute number of epilogue subtiles + constexpr int epi_m = size<0>(tile_shape_MNK) / size<0>(EpilogueTile{}); + constexpr int epi_n = size<1>(tile_shape_MNK) / size<1>(EpilogueTile{}); return epi_m * epi_n; } @@ -254,32 +278,25 @@ class CollectiveEpilogue< CUTLASS_HOST_DEVICE static constexpr int get_store_pipe_increment(TileShapeMNK tile_shape_MNK) { - if constexpr (ReuseSmemC) { - return get_load_pipe_increment(tile_shape_MNK); - } + return get_load_pipe_increment(tile_shape_MNK); + } - // Compute number of D subtiles - constexpr int epi_m = size<0>(tile_shape_MNK) / size<0>(SmemLayoutD{}); - constexpr int epi_n = size<1>(tile_shape_MNK) / size<1>(SmemLayoutD{}); - - return epi_m * epi_n; + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void + prefetch_tma_descriptors(Params const& epilogue_params) { + cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor()); + cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor()); } CUTLASS_HOST_DEVICE - CollectiveEpilogue(Params const& params_) - : params(params_), epilogue_op(params_.thread) { } + CollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors) + : params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {} CUTLASS_DEVICE bool - is_source_needed() { - return epilogue_op.is_source_needed(); - } - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& epilogue_params) { - cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor()); - cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor()); + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); } template< @@ -288,7 +305,7 @@ class CollectiveEpilogue< class TileCoordMNKL, class TiledMma > - CUTLASS_DEVICE void + CUTLASS_DEVICE auto load( LoadPipeline load_pipeline, LoadPipelineState load_pipe_producer_state, @@ -296,55 +313,90 @@ class CollectiveEpilogue< TileShapeMNK tile_shape_MNK, TileCoordMNKL tile_coord_mnkl, TiledMma tiled_mma, - [[maybe_unused]] int thread_idx, + int thread_idx, TensorStorage& shared_tensors) { using namespace cute; - using X = Underscore; - - int warp_idx = canonical_warp_idx(); - int warp_idx_in_warp_group = warp_idx % 4; - int lane_predicate = cute::elect_one_sync(); + using _X = Underscore; - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; - // Represent the full source tensor - Tensor mC_mnl = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (m,n,l) - Tensor gC_mnl = local_tile(mC_mnl, tile_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (TILE_M,TILE_N,m,n,l) - // Slice to get the gmem tile of C (gC) this CTA is currently responsible for - Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (TILE_M,TILE_N) - // Get the corresponding smem tile of C (sC) - Tensor sC = make_tensor(make_smem_ptr(shared_tensors.smem_C.data()), SmemLayoutC{}); // (TILE_M,TILE_N,PIPE) + // Represent the full source tensor, slice to get the tile this CTA is currently responsible for + Tensor mC_mnl = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor gC_mnl = local_tile(mC_mnl, tile_shape_MNK, make_coord(_,_,_), Step<_1,_1,_X>{}); // (CTA_M,CTA_N,m,n,l) + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (CTA_M,CTA_N) - // Prepare the thread(b)lock (G)mem to (S)mem TMA copy (bGS_) + // Apply epilogue subtile, get matching smem tensor + auto ptr_sC = make_smem_ptr(shared_tensors.smem_C.data()); + Tensor gC_epi = local_tile(gC, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor sC_epi = make_tensor(ptr_sC, SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + + // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); - Tensor bGS_gC = thrblk_g2s.partition_S(gC); // (TMA,TMA_M,TMA_N) - Tensor bGS_sC = thrblk_g2s.partition_D(sC); // (TMA,TMA_M,TMA_N,PIPE) + Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi); // (G2S,G2S_M,G2S_N,EPI_M,EPI_N) + Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (G2S,G2S_M,G2S_N,PIPE_C) + + // Get the fusion callbacks for the producer load warp + auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks( + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + EpilogueTile{}, + thread_idx); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Predication for TMA load (one thread issues TMA load) + bool issue_tma_load = cute::elect_one_sync(); - auto* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); - uint16_t mcast_mask = 0; + // Acquire the lock for the first stage + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); + + // Pre-loop fusion callback entry point + pld_callbacks.begin(tma_barrier, load_pipe_producer_state.count(), issue_tma_load); + + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gC_epi); ++epi_m) { + // Acquire the lock for this stage + constexpr uint16_t mcast_mask = 0; + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); + + // Loop fusion callback entry point + pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); + + // Execute the TMA load for C if needed + if (issue_tma_load && is_C_load_needed) { + copy(params.tma_load_c.with(*tma_barrier, mcast_mask), + bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_expect_transaction(load_pipe_producer_state); + } - // Execute the TMA load for C - if (warp_idx_in_warp_group == 0 and lane_predicate) { - load_pipeline.producer_acquire(load_pipe_producer_state); - copy(params.tma_load_c.with(*tma_barrier, mcast_mask), bGS_gC, bGS_sC(_,_,_,load_pipe_producer_state.index())); - load_pipeline.producer_commit(load_pipe_producer_state); + // Commit TMA loads for this stage and release the lock + load_pipeline.producer_commit(load_pipe_producer_state); + ++load_pipe_producer_state; + } } + + // Post-loop fusion callback entry point + pld_callbacks.end(); + + return load_pipe_producer_state; } - CUTLASS_DEVICE void + CUTLASS_DEVICE auto load_tail( LoadPipeline load_pipeline, LoadPipelineState load_pipe_producer_state) { - int warp_idx = canonical_warp_idx(); - int warp_idx_in_warp_group = warp_idx % 4; - int lane_predicate = cute::elect_one_sync(); - - if (warp_idx_in_warp_group == 0 and lane_predicate) { + bool issue_tma_load = cute::elect_one_sync(); + if (issue_tma_load) { load_pipeline.producer_tail(load_pipe_producer_state); } + + return load_pipe_producer_state; } template< @@ -354,7 +406,7 @@ class CollectiveEpilogue< class AccEngine, class AccLayout, class TiledMma > - CUTLASS_DEVICE void + CUTLASS_DEVICE auto store( LoadPipeline load_pipeline, LoadPipelineState load_pipe_consumer_state, @@ -368,7 +420,8 @@ class CollectiveEpilogue< int thread_idx, TensorStorage& shared_tensors) { using namespace cute; - using X = Underscore; + using _X = Underscore; + using ElementAccumulator = typename AccEngine::value_type; static_assert(is_rmem::value, "Accumulator must be RF resident."); static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); @@ -377,71 +430,78 @@ class CollectiveEpilogue< static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3"); static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); - // Separate out problem shape for convenience - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; auto mma_tile_m = size<0>(typename TiledMma::TiledShape_MNK{}); auto mma_tile_n = size<1>(typename TiledMma::TiledShape_MNK{}); - auto epi_tile_m = size<0>(EpilogueTileShape{}); - auto epi_tile_n = size<1>(EpilogueTileShape{}); - - // Represent the full output tensor - Tensor mD_mnl = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (m,n,l) - Tensor gD_mnl = local_tile(mD_mnl, tile_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (TILE_M,TILE_N,m,n,l) - - // Slice to get the tile this CTA is responsible for - auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; - Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (TILE_M,TILE_N) - - // Construct the smem tensors for source (sC) and output (sD) - Tensor sC = make_tensor(make_smem_ptr(shared_tensors.smem_C.data()), // (TILE_M,TILE_N) - SmemLayoutC{})(_,_,load_pipe_consumer_state.index()); - Tensor bEsD = make_tensor(make_smem_ptr(shared_tensors.smem_D.data()), // (EPI_TILE_M,EPI_TILE_N,PIPE) - SmemLayoutD{}); - - // Tile thread(b)lock tensors by (E)pilogue output tile shape (bE) - Tensor bEsC = local_tile(sC, EpilogueTileShape{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor bEgD = local_tile(gD, EpilogueTileShape{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - // Partition for register to smem copy (tRS_) - using CopyAtomR2S = cute::conditional_t, - Copy_Atom>,ElementD>, - Copy_Atom>; - TiledCopy tiled_r2s = make_tiled_copy_C_atom(CopyAtomR2S{}, tiled_mma); + auto epi_tile_m = size<0>(EpilogueTile{}); + auto epi_tile_n = size<1>(EpilogueTile{}); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mD_mnl = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor gD_mnl = local_tile(mD_mnl, tile_shape_MNK, make_coord(_,_,_), Step<_1,_1,_X>{}); // (CTA_M,CTA_N,m,n,l) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (CTA_M,CTA_N) + + // Apply epilogue subtiling, construct corresponding pipelined smem tensors + auto ptr_sC = make_smem_ptr(shared_tensors.smem_C.data()); + auto ptr_sD = make_smem_ptr(shared_tensors.smem_D.data()); + Tensor gD_epi = local_tile(gD, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor sC_epi = make_tensor(ptr_sC, SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sD_epi = make_tensor(ptr_sD, SmemLayoutD{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + // Get the smallest tiled copy we can use to retile the accumulators + using CopyAtomC = Copy_Atom; + TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma); + + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2s = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); - Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) Tensor tRS_sD = conditional_return( - thread_r2s.partition_D(recast(bEsC)), // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) - thread_r2s.partition_D(bEsD) ); // (R2S,R2S_M,R2S_N,PIPE) + thread_r2s.partition_D(recast(sC_epi)), // (R2S,R2S_M,R2S_N,PIPE_C) + thread_r2s.partition_D(sD_epi) ); // (R2S,R2S_M,R2S_N,PIPE_D) // Allocate register tensors - auto tRS_rD_shape = take<0,3>(shape(thread_r2s.partition_S(bEsD))); // (R2S,R2S_M,R2S_N) + auto tRS_rD_shape = take<0,3>(shape(thread_r2s.partition_S(sD_epi))); Tensor tRS_rC = make_tensor(tRS_rD_shape); // (R2S,R2S_M,R2S_N) - Tensor tRS_rD = make_tensor(tRS_rD_shape); // (R2S,R2S_M,R2S_N) + Tensor tRS_rD = make_tensor(tRS_rD_shape); // (R2S,R2S_M,R2S_N) - // Vectorized fragment view for thread epilogue op - Tensor tRS_rAcc_frg = recast(tRS_rAcc); - Tensor tRS_rC_frg = recast(tRS_rC); - Tensor tRS_rD_frg = recast(tRS_rD); + // Vectorized fragment view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + Tensor tRS_rAcc_frg = recast>(tRS_rAcc); + Tensor tRS_rD_frg = recast>(tRS_rD); + CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly"); - // Partition for smem to register copy (tSR_) - TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom{}, tiled_r2s); + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); - Tensor tSR_sC = thread_s2r.partition_S(bEsC); // (S2R,S2R_M,S2R_N,EPI_M,EPI_N) - Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N) + Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N) - // Partition for smem to gmem copy (tSG_) + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); - Tensor tSG_sD = conditional_return( - thrblk_s2g.partition_S(recast(bEsC)), // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) - thrblk_s2g.partition_S(bEsD) ); // (S2G,S2G_M,S2G_N,PIPE) - Tensor tSG_gD = thrblk_s2g.partition_D(bEgD); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + Tensor bSG_sD = conditional_return( + thrblk_s2g.partition_S(recast(sC_epi)), // (S2G,S2G_M,S2G_N,PIPE_C) + thrblk_s2g.partition_S(sD_epi) ); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) - CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % ThreadEpilogueOp::kCount == 0, "ThreadEpilogueOp does not vectorize properly"); CUTE_STATIC_ASSERT(mma_tile_m == epi_tile_m, "EPI_TILE_M must equal MMA_TILE_M"); CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + // Get the fusion callbacks for the consumer store warps + constexpr bool RefSrc = true; // Register tensors reference R2S copy src layout + auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks( + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + EpilogueTile{}, + tiled_copy_C_atom, + thread_idx, + tRS_rC); + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + // Thread synchronizer for previously issued waits or fences // to ensure visibility of smem reads/writes to threads or TMA unit auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 0); }; @@ -449,151 +509,142 @@ class CollectiveEpilogue< // Predication for TMA store (one warp issues TMA store) bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0; - if (epilogue_op.is_source_needed()) { - // Wait for epilogue load to fill smem buffer with C - load_pipeline.consumer_wait(load_pipe_consumer_state); + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. + // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. + // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. + LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; } - // Delay issue of TMA store by 1 iteration to achieve better instruction pipelining - PipelineState store_pipe_producer_state_prev = store_pipe_producer_state; - int epi_m_prev = 0, epi_n_prev = 0; + // + // BEGIN EPILOGUE + // + + // Pre-loop fusion callback entry point + cst_callbacks.begin(); // For each output tile CUTLASS_PRAGMA_UNROLL - for (int epi_n = 0; epi_n < size<3>(bEgD); ++epi_n) { + for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) { CUTLASS_PRAGMA_UNROLL - for (int epi_m = 0; epi_m < size<2>(bEgD); ++epi_m) { + for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) { // The current tile in accumulator int mma_m = epi_m; int mma_n = (epi_n * epi_tile_n) / mma_tile_n; Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); - // Elementwise operation with conversion - int r2s_v = epi_n * size(tRS_rD_frg); - if (epilogue_op.is_source_needed()) { - // Copy source tile to register from smem - if constexpr (cute::is_same_v) { - copy(tSR_sC(_,_,_,epi_m,epi_n), tSR_rC); - } - else { - copy(tiled_s2r, tSR_sC(_,_,_,epi_m,epi_n), tSR_rC); - } - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tRS_rD_frg); ++i) { - tRS_rD_frg(i) = epilogue_op(tRS_rAcc_frg_mn(r2s_v + i), tRS_rC_frg(i)); - } - } - else { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tRS_rD_frg); ++i) { - tRS_rD_frg(i) = epilogue_op(tRS_rAcc_frg_mn(r2s_v + i)); - } + // Wait for a smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); } + synchronize(); if constexpr (ReuseSmemC) { - // Issue the TMA store of the previous iteration - if (not (epi_m == 0 && epi_n == 0)) { - // Make sure smem writes are visible to TMA - cutlass::arch::fence_view_async_shared(); - synchronize(); // ensure all threads have issued their async fence - - // Write the tile to gmem from smem with TMA - if (issue_tma_store) { - copy(params.tma_store_d, tSG_sD(_,_,_,epi_m_prev,epi_n_prev), tSG_gD(_,_,_,epi_m_prev,epi_n_prev)); + // Let dma warp know smem buffer is consumed and empty after StagesD producer commits + if (issued_stores >= StagesD) { + if (is_producer_load_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); } + ++load_pipe_consumer_state; } + } - // Copy output tile to smem from register - if constexpr (cute::is_same_v) { - copy(tRS_rD, tRS_sD(_,_,_,epi_m,epi_n)); - } - else { - copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,epi_m,epi_n)); + if (is_producer_load_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state); + + if (is_C_load_needed) { + // Copy source tile from smem to register + copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); } } - else { - // Issue the TMA store of the previous iteration - if (not (epi_m == 0 && epi_n == 0)) { - // Make sure smem writes are visible to TMA - cutlass::arch::fence_view_async_shared(); - synchronize(); // ensure all threads have issued their async fence - // Write the tile to gmem from smem with TMA - if (issue_tma_store) { - copy(params.tma_store_d, tSG_sD(_,_,_,store_pipe_producer_state_prev.index()), tSG_gD(_,_,_,epi_m_prev,epi_n_prev)); - store_pipeline.producer_commit(store_pipe_producer_state_prev); - } - } + // First loop fusion callback entry point + cst_callbacks.step_begin(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); - // Wait for a smem buffer to be available - if (issue_tma_store) { - store_pipeline.producer_acquire(store_pipe_producer_state); + if (is_producer_load_needed) { + if constexpr (not ReuseSmemC) { + // Let producer load warp know smem buffers are consumed and empty + cutlass::arch::fence_view_async_shared(); + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; } - synchronize(); + ++load_wait_state; + } - // Copy tile to smem from register - if constexpr (cute::is_same_v) { - copy(tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); - } - else { - copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); - } + // Vectorized fragment loop with visitor callback entry point + int r2s_v = epi_n * size(tRS_rD_frg); + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tRS_rD_frg); ++epi_v) { + tRS_rD_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n); + } + + // Copy tile from register to smem + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); - // Advance pipeline state - store_pipe_producer_state_prev = store_pipe_producer_state; - ++store_pipe_producer_state; + // Next loop fusion callback entry point + constexpr bool issue_smem_store = true; // No smem store predication + cst_callbacks.step_next(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); + + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + if (issue_tma_store) { + copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); } - epi_m_prev = epi_m; - epi_n_prev = epi_n; - } - } + // Last loop fusion callback entry point + cst_callbacks.step_end(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); - if constexpr (ReuseSmemC) { - // Fence and issue the TMA store of the last iteration - cutlass::arch::fence_view_async_shared(); - synchronize(); // ensure all threads have issued their async fence - if (issue_tma_store) { - copy(params.tma_store_d, tSG_sD(_,_,_,epi_m_prev,epi_n_prev), tSG_gD(_,_,_,epi_m_prev,epi_n_prev)); - } + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + ++issued_stores; + } // for epi_m + } // for epi_n - // Arrive and advance pipeline state - if (issue_tma_store) { - store_pipeline.producer_commit(store_pipe_producer_state); - } - ++store_pipe_producer_state; + // Post-loop fusion callback entry point + cst_callbacks.end(); - // Wait for a smem buffer to be available - if (issue_tma_store) { - store_pipeline.producer_acquire(store_pipe_producer_state); - } - synchronize(); + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } - // Let dma warp know smem buffer is consumed and empty - if (epilogue_op.is_source_needed()) { - load_pipeline.consumer_release(store_pipe_producer_state); - } - } - else { - // Fence and issue the TMA store of the last iteration - cutlass::arch::fence_view_async_shared(); - synchronize(); // ensure all threads have issued their async fence - if (issue_tma_store) { - copy(params.tma_store_d, tSG_sD(_,_,_,store_pipe_producer_state_prev.index()), tSG_gD(_,_,_,epi_m_prev,epi_n_prev)); - store_pipeline.producer_commit(store_pipe_producer_state_prev); - } + CUTLASS_DEVICE auto + store_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state) { + // wait for all TMA stores to complete + store_pipeline.producer_tail(store_pipe_producer_state); + // reset store counter + issued_stores = 0; - // Let dma warp know smem buffer is consumed and empty - if (epilogue_op.is_source_needed()) { - load_pipeline.consumer_release(load_pipe_consumer_state); + if constexpr (ReuseSmemC) { + if (fusion_callbacks.is_producer_load_needed()) { + // Issue releases on up to StagesD previously issued TMA stores + constexpr int release_stages = + cute::min(StagesD, get_load_pipe_increment(CtaTileMNK{})); + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < release_stages; ++stage) { + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } } } + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); } private: Params const& params; - ThreadEpilogueOp epilogue_op; + FusionCallbacks fusion_callbacks; + int issued_stores = 0; }; diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp index 4a7978b2fd..070b8fb231 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp @@ -1,40 +1,41 @@ /*************************************************************************************************** - * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief Functor performing pipelined epilogues with bias add and elementwise activation functions. + This collective is now DEPRECATED, will be removed in the next release. Use EVT instead. */ #pragma once -#include "cutlass/cutlass.h" -#include "cutlass/arch/barrier.h" -#include "cutlass/epilogue/dispatch_policy.hpp" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/epilogue/thread/scale_type.h" - -#include "cute/tensor.hpp" +#include "sm90_epilogue_tma_warpspecialized.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -47,13 +48,14 @@ namespace collective { template < int StagesC_, int StagesD_, + int FragmentSize_, class BlockTileShape_, // (BLK_M,BLK_N,BLK_K) class EpilogueTileShape_, // (EPI_TILE_M,EPI_TILE_N) class ElementC_, class StrideC_, class ElementD_, class StrideD_, - class ThreadEpilogueOp_, + class FusionCallbacks_, class CopyOpG2S_, class SmemLayoutAtomC_, class CopyOpS2R_, @@ -61,620 +63,89 @@ template < class SmemLayoutAtomD_, class CopyOpR2S_ > -class CollectiveEpilogue< - Sm90TmaWarpSpecializedBiasElementwise, - BlockTileShape_, - EpilogueTileShape_, - ElementC_, - StrideC_, - ElementD_, - StrideD_, - ThreadEpilogueOp_, - CopyOpG2S_, - SmemLayoutAtomC_, - CopyOpS2R_, - CopyOpS2G_, - SmemLayoutAtomD_, - CopyOpR2S_ +class Sm90EpilogueTmaWarpSpecializedBiasElementwise + : public CollectiveEpilogue< + Sm90TmaWarpSpecialized, + BlockTileShape_, + EpilogueTileShape_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_ > { -public: - // - // Type Aliases - // - using DispatchPolicy = Sm90TmaWarpSpecializedBiasElementwise; - using BlockTileShape = BlockTileShape_; - using EpilogueTileShape = EpilogueTileShape_; - using ThreadEpilogueOp = ThreadEpilogueOp_; - using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; - using ElementCompute = typename ThreadEpilogueOp::ElementCompute; - using ElementScalar = ElementCompute; - using ElementBias = typename detail::IsThreadEpilogueOpWithBias::type; - using ElementT = typename ThreadEpilogueOp::ElementT; - using ElementOutput = typename ThreadEpilogueOp::ElementOutput; - using ElementC = ElementC_; - using StrideC = StrideC_; - using ElementD = ElementD_; - using StrideD = StrideD_; - using ActivationFunctor = typename ThreadEpilogueOp::ActivationFunctor; - using BinaryOp = typename ThreadEpilogueOp::BinaryOp; - - using CopyOpG2S = CopyOpG2S_; - using SmemLayoutAtomC = SmemLayoutAtomC_; - using CopyOpS2R = CopyOpS2R_; - using CopyOpS2G = CopyOpS2G_; - using SmemLayoutAtomD = SmemLayoutAtomD_; - using CopyOpR2S = CopyOpR2S_; - - using GmemTiledCopyC = SM90_TMA_LOAD; - using GmemTiledCopyD = SM90_TMA_STORE; - - constexpr static bool StoreT = ThreadEpilogueOp::kStoreT; - constexpr static int kOutputAlignment = ThreadEpilogueOp::kCount; - static_assert(detail::IsThreadEpilogueOpWithBias::value, - "Epilogue dispatch policy Sm90TmaWarpSpecializedBiasElementwise requires the use of a thread-level epiogue that supports bias calculation"); - constexpr static bool iskThreadEpilogueOpWithBias = true; - using AlignmentType = typename uint_bit::value * kOutputAlignment>::type; - - static_assert(!is_layout::value && is_tuple::value, "EpilogueTileShape must be a cute::Shape"); - static_assert(rank(BlockTileShape{}) == 3, "BlockTileShape must be rank-3: [BLK_M,BLK_N,BLK_K]"); - static_assert(rank(EpilogueTileShape{}) == 2, "EpilogueTileShape must be rank-2: [EPI_TILE_M,EPI_TILE_N]"); - static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); - private: - using InternalElementC = cute::conditional_t,ElementD,ElementC>; // prevents void ref breakages - constexpr static int StagesC = StagesC_; - constexpr static int StagesD = StagesD_; - constexpr static bool is_source_supported = not cute::is_void_v; - static_assert((cute::is_void_v && not is_source_supported) || (not cute::is_void_v && is_source_supported), - "Inconsistent C type and Scale kind"); - - // internal optimization to reuse C shared memory for storing D - using SmemLayoutAtomBitsC = decltype(downcast::value>(SmemLayoutAtomC{})); - using SmemLayoutAtomBitsD = decltype(downcast::value>(SmemLayoutAtomD{})); - constexpr static bool ReuseSmemC = is_source_supported && - sizeof(InternalElementC) == sizeof(ElementD) && - StrideC{} == StrideD{} && - cute::is_same_v && - not StoreT; - + using Impl = + CollectiveEpilogue< + Sm90TmaWarpSpecialized, + BlockTileShape_, + EpilogueTileShape_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_ + >; public: - using SmemLayoutC = decltype(tile_to_shape( - SmemLayoutAtomC{}, - make_shape(size<0>(BlockTileShape{}), size<1>(BlockTileShape{}), Int{}), - cute::conditional_t(StrideC{}) == 1, Step<_2,_1,_3>, Step<_1,_2,_3>>{} )); - using SmemLayoutD = decltype(tile_to_shape( - SmemLayoutAtomD{}, - make_shape(size<0>(EpilogueTileShape{}), size<1>(EpilogueTileShape{}), Int{}), - cute::conditional_t(StrideD{}) == 1, Step<_2,_1,_3>, Step<_1,_2,_3>>{} )); - - // TMA pipeline for loading C - using LoadPipeline = cutlass::PipelineTransactionAsync; - using LoadPipelineState = cutlass::PipelineState; - constexpr static uint32_t TmaTransactionBytes = - size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof(InternalElementC)); + using DispatchPolicy = Sm90TmaWarpSpecializedBiasElementwise; + using ElementCompute = typename Impl::ThreadEpilogueOp::ElementCompute; + using ElementBias = typename Impl::ThreadEpilogueOp::ElementBias; + using ElementT = typename Impl::ThreadEpilogueOp::ElementAux; - // TMA pipeline for storing D and T. ReuseSmemC cannot be set to true if StoreT is enabled. - using StorePipeline = cutlass::PipelineTmaStore; - using StorePipelineState = cutlass::PipelineState; - - struct SharedStorage { - struct TensorStorage : aligned_struct<128> { - cute::conditional_t, - array_aligned> smem_C; - alignas(128) cute::conditional_t, - array_aligned> smem_D; - alignas(128) cute::conditional_t, - array_aligned> smem_T; - } tensors; - - using PipelineStorage = typename LoadPipeline::SharedStorage; - PipelineStorage pipeline; - }; - using TensorStorage = typename SharedStorage::TensorStorage; - using PipelineStorage = typename SharedStorage::PipelineStorage; + // Constructor inheritance + using Impl::Impl; // Host side epilogue arguments - struct Arguments { - typename ThreadEpilogueOp::Params thread; - ElementC const* ptr_C; - StrideC dC; - ElementD const* ptr_D; - StrideD dD; - ElementBias const* ptr_Bias = nullptr; - ElementT const* ptr_T = nullptr; - }; - - // Device side epilgoue params - struct Params { - using TMA_C = decltype(make_tma_copy( - CopyOpG2S{}, - make_tensor(static_cast(nullptr), - repeat_like(StrideC{}, int32_t(0)), StrideC{}), - SmemLayoutC{}(_,_,0))); - using TMA_D = decltype(make_tma_copy( - CopyOpS2G{}, - make_tensor(static_cast(nullptr), - repeat_like(StrideD{}, int32_t(0)), StrideD{}), - SmemLayoutD{}(_,_,0))); - using TMA_T = decltype(make_tma_copy( - CopyOpS2G{}, - make_tensor(static_cast(nullptr), - repeat_like(StrideD{}, int32_t(0)), StrideD{}), - SmemLayoutD{}(_,_,0))); - typename ThreadEpilogueOp::Params thread{}; - TMA_C tma_load_c; - TMA_D tma_store_d; - TMA_T tma_store_t; + struct [[deprecated("use Sm90TmaWarpSpecialized Arguments instead")]] + Arguments { + struct ThreadArgs { + ElementCompute alpha; + ElementCompute beta; + ElementCompute const *alpha_ptr; + ElementCompute const *beta_ptr; + } thread; + ElementC_ const* ptr_C; + StrideC_ dC; + ElementD_* ptr_D; + StrideD_ dD; ElementBias const* ptr_Bias = nullptr; - }; - - // - // Methods - // - - template - static constexpr Params - to_underlying_arguments( - ProblemShape const& problem_shape, - Arguments const& args, - [[maybe_unused]] void* workspace) { - // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(problem_shape, Int<1>{}); - auto M = get<0>(problem_shape_MNKL); - auto N = get<1>(problem_shape_MNKL); - auto L = get<3>(problem_shape_MNKL); - - typename Params::TMA_C tma_load_c = [&]() { - if constexpr (not cute::is_void_v) { - Tensor tensor_c = make_tensor(static_cast(args.ptr_C), make_layout(make_shape(M,N,L), args.dC)); - return make_tma_copy( - CopyOpG2S{}, - tensor_c, - SmemLayoutC{}(_,_,0)); - } - else { - return typename Params::TMA_C{}; - } - }(); - - Tensor tensor_d = make_tensor(args.ptr_D, make_layout(make_shape(M,N,L), args.dD)); - typename Params::TMA_D tma_store_d = make_tma_copy( - CopyOpS2G{}, - tensor_d, - SmemLayoutD{}(_,_,0)); - - typename Params::TMA_T tma_store_t = [&]() { - if constexpr (StoreT) { - Tensor tensor_t = make_tensor(args.ptr_T, make_layout(make_shape(M,N,L), args.dD)); - return make_tma_copy( - CopyOpS2G{}, - tensor_t, - SmemLayoutD{}(_,_,0)); + ElementT* ptr_T = nullptr; + + CUTLASS_HOST_DEVICE + operator typename Impl::Arguments() const { + typename Impl::Arguments arguments; + arguments.thread.alpha = thread.alpha; + arguments.thread.beta = thread.beta; + arguments.thread.alpha_ptr = thread.alpha_ptr; + arguments.thread.beta_ptr = thread.beta_ptr; + if constexpr (not cute::is_void_v) { + arguments.thread.bias_ptr = ptr_Bias; } - else { - return typename Params::TMA_T{}; + if constexpr (not cute::is_void_v) { + arguments.thread.aux_ptr = ptr_T; + arguments.thread.dAux = dD; } - }(); - - return { - args.thread, - tma_load_c, - tma_store_d, - tma_store_t, - args.ptr_Bias - }; - } - - template - CUTLASS_HOST_DEVICE - static constexpr int - get_load_pipe_increment(TileShapeMNK tile_shape_MNK) { - // Compute number of C subtiles (currently always one) - constexpr int epi_m = size<0>(tile_shape_MNK) / size<0>(SmemLayoutC{}); - constexpr int epi_n = size<1>(tile_shape_MNK) / size<1>(SmemLayoutC{}); + arguments.ptr_C = ptr_C; + arguments.dC = dC; + arguments.ptr_D = ptr_D; + arguments.dD = dD; - return epi_m * epi_n; - } - - template - CUTLASS_HOST_DEVICE - static constexpr int - get_store_pipe_increment(TileShapeMNK tile_shape_MNK) { - if constexpr (ReuseSmemC) { - return get_load_pipe_increment(tile_shape_MNK); + return arguments; } + }; - // Compute number of D subtiles - constexpr int epi_m = size<0>(tile_shape_MNK) / size<0>(SmemLayoutD{}); - constexpr int epi_n = size<1>(tile_shape_MNK) / size<1>(SmemLayoutD{}); - - return epi_m * epi_n; - } - - CUTLASS_HOST_DEVICE - CollectiveEpilogue(Params const& params_) - : params(params_), epilogue_op(params_.thread) { } - - CUTLASS_DEVICE - bool - is_source_needed() { - return is_source_supported && epilogue_op.is_source_needed(); - } - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& epilogue_params) { - cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor()); - cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor()); - if constexpr (StoreT) { - cute::prefetch_tma_descriptor(epilogue_params.tma_store_t.get_tma_descriptor()); - } - } - - template< - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class TiledMma - > - CUTLASS_DEVICE void - load( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_MNK, - TileCoordMNKL tile_coord_mnkl, - TiledMma tiled_mma, - [[maybe_unused]] int thread_idx, - TensorStorage& shared_tensors) { - using namespace cute; - using X = Underscore; - - int warp_idx = canonical_warp_idx(); - int warp_idx_in_warp_group = warp_idx % 4; - int lane_predicate = cute::elect_one_sync(); - - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; - - // Represent the full source tensor - Tensor mC_mnl = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (m,n,l) - Tensor gC_mnl = local_tile(mC_mnl, tile_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (TILE_M,TILE_N,m,n,l) - // Slice to get the gmem tile of C (gC) this CTA is currently responsible for - Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (TILE_M,TILE_N) - // Get the corresponding smem tile of C (sC) - Tensor sC = make_tensor(make_smem_ptr(shared_tensors.smem_C.data()), SmemLayoutC{}); // (TILE_M,TILE_N,PIPE) - - // Prepare the thread(b)lock (G)mem to (S)mem TMA copy (bGS_) - ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); - Tensor bGS_gC = thrblk_g2s.partition_S(gC); // (TMA,TMA_M,TMA_N) - Tensor bGS_sC = thrblk_g2s.partition_D(sC); // (TMA,TMA_M,TMA_N,PIPE) - - auto* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); - uint16_t mcast_mask = 0; - - // Execute the TMA load for C - if (warp_idx_in_warp_group == 0 and lane_predicate) { - load_pipeline.producer_acquire(load_pipe_producer_state); - copy(params.tma_load_c.with(*tma_barrier, mcast_mask), bGS_gC, bGS_sC(_,_,_,load_pipe_producer_state.index())); - load_pipeline.producer_commit(load_pipe_producer_state); - } - } - - CUTLASS_DEVICE void - load_tail( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_producer_state) { - int warp_idx = canonical_warp_idx(); - int warp_idx_in_warp_group = warp_idx % 4; - int lane_predicate = cute::elect_one_sync(); - - if (warp_idx_in_warp_group == 0 and lane_predicate) { - load_pipeline.producer_tail(load_pipe_producer_state); - } - } - - template< - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class AccEngine, class AccLayout, - class TiledMma - > - CUTLASS_DEVICE void - store( - LoadPipeline load_pipeline, - LoadPipelineState load_pipe_consumer_state, - StorePipeline store_pipeline, - StorePipelineState store_pipe_producer_state, - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_MNK, - TileCoordMNKL tile_coord_mnkl, - cute::Tensor accumulators, - TiledMma tiled_mma, - int thread_idx, - TensorStorage& shared_tensors) { - using namespace cute; - using X = Underscore; - - static_assert(is_rmem::value, "Accumulator must be RF resident."); - static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(is_static::value, "TileShapeMNK must be static"); - static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3"); - static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); - - // Separate out problem shape for convenience - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - auto mma_tile_m = size<0>(typename TiledMma::TiledShape_MNK{}); - auto mma_tile_n = size<1>(typename TiledMma::TiledShape_MNK{}); - auto epi_tile_m = size<0>(EpilogueTileShape{}); - auto epi_tile_n = size<1>(EpilogueTileShape{}); - - // Represent the full output tensor - Tensor mD_mnl = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (m,n,l) - Tensor gD_mnl = local_tile(mD_mnl, tile_shape_MNK, make_coord(_,_,_), Step<_1, _1, X>{}); // (TILE_M,TILE_N,m,n,l) - Tensor mT_mnl = params.tma_store_t.get_tma_tensor(make_shape(M,N,L)); // (m,n,l) - Tensor gT_mnl = local_tile(mT_mnl, tile_shape_MNK, make_coord(_,_,_), Step<_1, _1, X>{}); // (TILE_M,TILE_N,m,n,l) - Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_Bias), make_shape(M,N,L), Stride<_1, _0, _0>{}); // (m,n,l) - Tensor gBias_mnl = local_tile(mBias_mnl, tile_shape_MNK, make_coord(_,_,_), Step<_1,_1,X>{}); // (TILE_M,TILE_N,m,n,l) - - // Slice to get the tile this CTA is responsible for - auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; - Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (TILE_M,TILE_N) - Tensor gT = gT_mnl(_,_,m_coord,n_coord,l_coord); // (TILE_M,TILE_N) - Tensor gBias = gBias_mnl(_,_,m_coord,n_coord,l_coord); // (TILE_M,TILE_N) - - // Construct the smem tensors for source (sC) and output (sD, sT) - Tensor sC = make_tensor(make_smem_ptr(shared_tensors.smem_C.data()), // (TILE_M,TILE_N) - SmemLayoutC{})(_,_,load_pipe_consumer_state.index()); - Tensor bEsD = make_tensor(make_smem_ptr(shared_tensors.smem_D.data()), // (EPI_TILE_M,EPI_TILE_N,PIPE) - SmemLayoutD{}); - Tensor bEsT = make_tensor(make_smem_ptr(shared_tensors.smem_T.data()), // (EPI_TILE_M,EPI_TILE_N,PIPE) - SmemLayoutD{}); - - // Tile thread(b)lock tensors by (E)pilogue output tile shape (bE) - Tensor bEsC = local_tile(sC, EpilogueTileShape{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor bEgD = local_tile(gD, EpilogueTileShape{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor bEgT = local_tile(gT, EpilogueTileShape{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor bEgBias = local_tile(gBias, EpilogueTileShape{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - // Partition for register to smem copy (tRS_) - using CopyAtomR2S = cute::conditional_t, - Copy_Atom>,ElementD>, - Copy_Atom>; - TiledCopy tiled_r2s = make_tiled_copy_C_atom(CopyAtomR2S{}, tiled_mma); - ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); - Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) - Tensor tRS_sD = conditional_return( - thread_r2s.partition_D(recast(bEsC)), // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) - thread_r2s.partition_D(bEsD) ); // (R2S,R2S_M,R2S_N,PIPE) - Tensor tRS_sT = thread_r2s.partition_D(bEsT); // (R2S,R2S_M,R2S_N,PIPE) - - // Allocate register tensors - auto tRS_rD_shape = take<0,3>(shape(thread_r2s.partition_S(bEsD))); // (R2S,R2S_M,R2S_N) - Tensor tRS_rC = make_tensor(tRS_rD_shape); // (R2S,R2S_M,R2S_N) - Tensor tRS_rD = make_tensor(tRS_rD_shape); // (R2S,R2S_M,R2S_N) - Tensor tRS_rT = make_tensor(tRS_rD_shape); // (R2S,R2S_M,R2S_N) - - Tensor tRS_gBias = thread_r2s.partition_S(bEgBias); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) - Tensor tRS_rBias = make_tensor(take<0,3>(shape(tRS_gBias))); // (R2S,R2S_M,R2S_N) - - // Vectorized fragment view for thread epilogue op - Tensor tRS_rAcc_frg = recast(tRS_rAcc); - Tensor tRS_rC_frg = recast(tRS_rC); - Tensor tRS_rD_frg = recast(tRS_rD); - Tensor tRS_rT_frg = recast(tRS_rT); - Tensor tRS_rBias_frg = recast(tRS_rBias); - - // thread::LinearCombinationBiasElementwise expects that the bias passed in is of - // type ElementCompute. Therefore, conversion from type ElementBias to ElementCompute - // is needed before calling the thread-level epilogue. - cutlass::NumericArrayConverter bias_converter; - - // Partition for smem to register copy (tSR_) - TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom{}, tiled_r2s); - ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); - Tensor tSR_sC = thread_s2r.partition_S(bEsC); // (S2R,S2R_M,S2R_N,EPI_M,EPI_N) - Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N) - - // Partition for smem to gmem copy (tSG_) - ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); - Tensor tSG_sD = conditional_return( - thrblk_s2g.partition_S(recast(bEsC)), // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) - thrblk_s2g.partition_S(bEsD) ); // (S2G,S2G_M,S2G_N,PIPE) - Tensor tSG_gD = thrblk_s2g.partition_D(bEgD); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) - - ThrCopy thrblk_s2g_t = params.tma_store_t.get_slice(Int<0>{}); - Tensor tSG_sT = thrblk_s2g_t.partition_S(bEsT); // (S2G,S2G_M,S2G_N,PIPE) - Tensor tSG_gT = thrblk_s2g_t.partition_D(bEgT); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) - - CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % ThreadEpilogueOp::kCount == 0, "ThreadEpilogueOp does not vectorize properly"); - CUTE_STATIC_ASSERT(mma_tile_m == epi_tile_m, "EPI_TILE_M must equal MMA_TILE_M"); - CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); - - // Thread synchronizer for previously issued waits or fences - // to ensure visibility of smem reads/writes to threads or TMA unit - auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 0); }; - - // Predication for TMA store (one warp issues TMA store) - bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0; - - if (is_source_supported && epilogue_op.is_source_needed()) { - // Wait for epilogue load to fill smem buffer with C - load_pipeline.consumer_wait(load_pipe_consumer_state); - } - - // Delay issue of TMA store by 1 iteration to achieve better instruction pipelining - PipelineState store_pipe_producer_state_prev = store_pipe_producer_state; - int epi_m_prev = 0, epi_n_prev = 0; - - // For each output tile - CUTLASS_PRAGMA_UNROLL - for (int epi_n = 0; epi_n < size<3>(bEgD); ++epi_n) { - CUTLASS_PRAGMA_UNROLL - for (int epi_m = 0; epi_m < size<2>(bEgD); ++epi_m) { - // The current tile in accumulator - int mma_m = epi_m; - int mma_n = (epi_n * epi_tile_n) / mma_tile_n; - Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); - - // Copy bias to registers from gmem - copy(tRS_gBias(_,_,_,epi_m,epi_n), tRS_rBias); - - // Elementwise operation with conversion - int r2s_v = epi_n * size(tRS_rD_frg); - if (is_source_supported && epilogue_op.is_source_needed()) { - // Copy source tile to registers from smem - if constexpr (cute::is_same_v) { - copy(tSR_sC(_,_,_,epi_m,epi_n), tSR_rC); - } - else { - copy(tiled_s2r, tSR_sC(_,_,_,epi_m,epi_n), tSR_rC); - } - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tRS_rD_frg); ++i) { - typename ThreadEpilogueOp::FragmentCompute converted_bias = bias_converter(tRS_rBias_frg(i)); - epilogue_op(tRS_rD_frg(i), tRS_rT_frg(i), tRS_rAcc_frg_mn(r2s_v + i), tRS_rC_frg(i), converted_bias); - } - } - else { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tRS_rD_frg); ++i) { - typename ThreadEpilogueOp::FragmentCompute converted_bias = bias_converter(tRS_rBias_frg(i)); - epilogue_op(tRS_rD_frg(i), tRS_rT_frg(i), tRS_rAcc_frg_mn(r2s_v + i), converted_bias); - } - } - - if constexpr (ReuseSmemC) { - // If ReuseSmemC is true, StoreT must be false. Therefore, we do not perform copies for T in this block. - - // Issue the TMA store of the previous iteration - if (not (epi_m == 0 && epi_n == 0)) { - // Make sure smem writes are visible to TMA - cutlass::arch::fence_view_async_shared(); - synchronize(); // ensure all threads have issued their async fence - - // Write the tile to gmem from smem with TMA - if (issue_tma_store) { - copy(params.tma_store_d, tSG_sD(_,_,_,epi_m_prev,epi_n_prev), tSG_gD(_,_,_,epi_m_prev,epi_n_prev)); - } - } - - // Copy output tile to smem from register - if constexpr (cute::is_same_v) { - copy(tRS_rD, tRS_sD(_,_,_,epi_m,epi_n)); - } - else { - copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,epi_m,epi_n)); - } - } - else { - // Issue the TMA store of the previous iteration - if (not (epi_m == 0 && epi_n == 0)) { - // Make sure smem writes are visible to TMA - cutlass::arch::fence_view_async_shared(); - synchronize(); // ensure all threads have issued their async fence - - // Write the tile to gmem from smem with TMA - if (issue_tma_store) { - copy(params.tma_store_d, tSG_sD(_,_,_,store_pipe_producer_state_prev.index()), tSG_gD(_,_,_,epi_m_prev,epi_n_prev)); - if constexpr (StoreT) { - copy(params.tma_store_t, tSG_sT(_,_,_,store_pipe_producer_state_prev.index()), tSG_gT(_,_,_,epi_m_prev,epi_n_prev)); - } - store_pipeline.producer_commit(store_pipe_producer_state_prev); - } - } - - // Wait for a smem buffer to be available - if (issue_tma_store) { - store_pipeline.producer_acquire(store_pipe_producer_state); - } - synchronize(); - - // Copy tile to smem from register - if constexpr (cute::is_same_v) { - copy(tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); - if constexpr (StoreT) { - copy(tRS_rT, tRS_sT(_,_,_,store_pipe_producer_state.index())); - } - } - else { - copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); - if constexpr (StoreT) { - copy(tiled_r2s, tRS_rT, tRS_sT(_,_,_,store_pipe_producer_state.index())); - } - } - - // Advance pipeline state - store_pipe_producer_state_prev = store_pipe_producer_state; - ++store_pipe_producer_state; - } - - epi_m_prev = epi_m; - epi_n_prev = epi_n; - } - } - - if constexpr (ReuseSmemC) { - // If ReuseSmemC is true, StoreT must be false. Therefore, we do not perform copies for T in this block. - - // Fence and issue the TMA store of the last iteration - cutlass::arch::fence_view_async_shared(); - synchronize(); // ensure all threads have issued their async fence - if (issue_tma_store) { - copy(params.tma_store_d, tSG_sD(_,_,_,epi_m_prev,epi_n_prev), tSG_gD(_,_,_,epi_m_prev,epi_n_prev)); - } - - // Arrive and advance pipeline state - if (issue_tma_store) { - store_pipeline.producer_commit(store_pipe_producer_state); - } - ++store_pipe_producer_state; - - // Wait for a smem buffer to be available - if (issue_tma_store) { - store_pipeline.producer_acquire(store_pipe_producer_state); - } - synchronize(); - - // Let dma warp know smem buffer is consumed and empty - if (is_source_supported && epilogue_op.is_source_needed()) { - load_pipeline.consumer_release(store_pipe_producer_state); - } - } - else { - // Fence and issue the TMA store of the last iteration - cutlass::arch::fence_view_async_shared(); - synchronize(); // ensure all threads have issued their async fence - if (issue_tma_store) { - copy(params.tma_store_d, tSG_sD(_,_,_,store_pipe_producer_state_prev.index()), tSG_gD(_,_,_,epi_m_prev,epi_n_prev)); - if (StoreT) { - copy(params.tma_store_t, tSG_sT(_,_,_,store_pipe_producer_state_prev.index()), tSG_gT(_,_,_,epi_m_prev,epi_n_prev)); - } - store_pipeline.producer_commit(store_pipe_producer_state_prev); - } - - // Let dma warp know smem buffer is consumed and empty - if (epilogue_op.is_source_needed()) { - load_pipeline.consumer_release(load_pipe_consumer_state); - } - } - } - -private: - Params const& params; - ThreadEpilogueOp epilogue_op; }; diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp index c3fb61eff3..639115c431 100644 --- a/include/cutlass/epilogue/dispatch_policy.hpp +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -1,24 +1,30 @@ /*************************************************************************************************** - * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ @@ -33,20 +39,25 @@ namespace cutlass::epilogue { ////////////////////////////////////////////////////////////////////////////// -// Epilogue schedule types that can be used for categorical dispatch +////////////////////////////////////////////////////////////////////////////// +// +// Builder Epilogue Schedules +// +////////////////////////////////////////////////////////////////////////////// + struct NoSmemWarpSpecialized {}; struct TmaWarpSpecialized {}; struct TmaWarpSpecializedCooperative {}; - +// DEPRECATED schedules, will be removed in next release struct TmaWarpSpecializedElementwiseBase : public TmaWarpSpecialized {}; struct TmaWarpSpecializedCooperativeElementwiseBase : public TmaWarpSpecializedCooperative {}; - template < template class ActivationFunctor_, thread::ScaleType::Kind Scale_ = thread::ScaleType::Default, FloatRoundStyle Round_ = FloatRoundStyle::round_to_nearest > -struct TmaWarpSpecializedElementwise : public TmaWarpSpecializedElementwiseBase { +struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombEltAct instead")]] +TmaWarpSpecializedElementwise : public TmaWarpSpecializedElementwiseBase { template using ActivationFunctor = ActivationFunctor_; static constexpr thread::ScaleType::Kind Scale = Scale_; @@ -58,7 +69,8 @@ template < thread::ScaleType::Kind Scale_ = thread::ScaleType::Default, FloatRoundStyle Round_ = FloatRoundStyle::round_to_nearest > -struct TmaWarpSpecializedCooperativeElementwise : public TmaWarpSpecializedCooperativeElementwiseBase { +struct [[deprecated("Use TmaWarpSpecializedCooperative with fusion::LinCombEltAct instead")]] +TmaWarpSpecializedCooperativeElementwise : public TmaWarpSpecializedCooperativeElementwiseBase { template using ActivationFunctor = ActivationFunctor_; static constexpr thread::ScaleType::Kind Scale = Scale_; @@ -75,7 +87,8 @@ template < bool StoreT_, class ElementBias_ > -struct TmaWarpSpecializedBiasElementwise : public TmaWarpSpecializedBiasElementwiseBase { +struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombPerRowBiasEltActAux instead")]] +TmaWarpSpecializedBiasElementwise : public TmaWarpSpecializedBiasElementwiseBase { template using ActivationFunctor = ActivationFunctor_; using ElementT = ElementT_; @@ -94,7 +107,8 @@ template < bool StoreT_, class ElementBias_ > -struct TmaWarpSpecializedCooperativeBiasElementwise : public TmaWarpSpecializedCooperativeBiasElementwiseBase { +struct [[deprecated("Use TmaWarpSpecializedCooperative with fusion::LinCombPerRowBiasEltActAux instead")]] +TmaWarpSpecializedCooperativeBiasElementwise : public TmaWarpSpecializedCooperativeBiasElementwiseBase { template using ActivationFunctor = ActivationFunctor_; @@ -107,28 +121,36 @@ struct TmaWarpSpecializedCooperativeBiasElementwise : public TmaWarpSpecializedC using ElementBias = ElementBias_; }; +////////////////////////////////////////////////////////////////////////////// // -// Collective Epilogue Policies +// Collective Dispatch Policies // +////////////////////////////////////////////////////////////////////////////// template< int StagesC_, int StagesD_, - bool DisableSmemReuseC_ + int FragmentSize_, + bool ReuseSmemC_ > struct Sm90TmaWarpSpecialized { constexpr static int StagesC = StagesC_; constexpr static int StagesD = StagesD_; - constexpr static bool DisableSmemReuseC = DisableSmemReuseC_; + constexpr static int FragmentSize = FragmentSize_; + constexpr static bool ReuseSmemC = ReuseSmemC_; }; + +// DEPRECATED policies, will be removed in next release template< int StagesC_, - int StagesD_ + int StagesD_, + int FragmentSize_ = 2 > struct Sm90TmaWarpSpecializedBiasElementwise { constexpr static int StagesC = StagesC_; constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; }; ////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/callbacks.hpp b/include/cutlass/epilogue/fusion/callbacks.hpp new file mode 100644 index 0000000000..e9b8f65194 --- /dev/null +++ b/include/cutlass/epilogue/fusion/callbacks.hpp @@ -0,0 +1,87 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Dispatch interface for epilogue fusion callbacks +// For visitor fusions, this is just a convenience wrapper to provide metadata and non-nested args. +// It is also valid to just pass visitor callbacks directly to the collective, e.g. fusion::Sm90LinearCombination, +// provided the collective supports a visitor callbacks interface. This is useful for implementing custom fusions. +template < + class DispatchPolicy, // specialize on collective's dispatch policy since callbacks API will depend on collective's algorithm + class Operation, // the fusion operation being performed, e.g. fusion::LinearCombination + class CtaTile_MNK, // computed tile per CTA + class EpilogueTile_MN, // epilogue subtile size + class... Args // callbacks implementation dependent args (e.g. copy atoms, smem layouts) +> +struct FusionCallbacks { + static_assert(cutlass::detail::dependent_false, "Could not find a callbacks specialization."); +}; + +// Metadata helper to handle custom EVTs or other non-FusionCallbacks types +template +struct FusionCallbacksTraits { + using DispatchPolicy = void; + using Operation = T; + using CtaTile_MNK = void; + using EpilogueTile_MN = void; +}; + +template < + class DispatchPolicy_, + class Operation_, + class CtaTile_MNK_, + class EpilogueTile_MN_, + class... Args +> +struct FusionCallbacksTraits< + FusionCallbacks +> { + using DispatchPolicy = DispatchPolicy_; + using Operation = Operation_; + using CtaTile_MNK = CtaTile_MNK_; + using EpilogueTile_MN = EpilogueTile_MN_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp new file mode 100644 index 0000000000..14db464397 --- /dev/null +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -0,0 +1,254 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Fusion Operations +// Template args must not be implementation dependent +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct FusionOperation { + // metadata types/queries that can be overrided + using ElementOutput = void; + using ElementCompute = void; + static constexpr bool IsSourceSupported = false; + + using ElementScalar = void; + static constexpr int AlignmentScalar = 0; + static constexpr bool IsScaleFactorSupported = false; + static constexpr bool IsPerRowScaleSupported = false; + + using ElementBias = void; + static constexpr int AlignmentBias = 0; + static constexpr bool IsPerRowBiasSupported = false; + template using ActivationFn = void; + static constexpr bool IsEltActSupported = false; + + using ElementAux = void; + using GmemLayoutTagAux = void; + static constexpr int AlignmentAux = 0; + static constexpr bool IsAuxOutSupported = false; + using ElementAmax = void; + static constexpr bool IsAbsMaxSupported = false; +}; + +// D = alpha * acc +template< + class ElementOutput_, + class ElementCompute_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledAcc : FusionOperation { + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementScalar = ElementScalar_; + static constexpr int AlignmentScalar = 1; + static constexpr auto RoundStyle = RoundStyle_; +}; + +// D = alpha * acc + beta * C +template< + class ElementOutput_, + class ElementCompute_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinearCombination + : ScaledAcc { + static constexpr bool IsSourceSupported = true; +}; + +// D = activation(alpha * acc + beta * C) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombEltAct + : LinearCombination { + template + using ActivationFn = ActivationFn_; + static constexpr bool IsEltActSupported = true; +}; + +// D = alpha * acc + beta * C + per-row bias +template< + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerRowBias + : LinearCombination { + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsPerRowBiasSupported = true; +}; + +// D = activation(alpha * acc + beta * C + per-row bias) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerRowBiasEltAct + : LinCombPerRowBias { + template + using ActivationFn = ActivationFn_; + static constexpr bool IsEltActSupported = true; +}; + +// D = activation(alpha * acc + beta * C + per-row bias) +// aux = alpha * acc + beta * C + per-row bias +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementBias_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / sizeof_bits_v, + int AlignmentBias_ = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerRowBiasEltActAux + : LinCombPerRowBiasEltAct { + using ElementAux = ElementAux_; + using GmemLayoutTagAux = GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxOutSupported = true; +}; + +// D = activation(per-row alpha * acc + per-row beta * C + per-row bias) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, // per-row alpha/beta + int AlignmentBias_ = 128 / sizeof_bits_v, + int AlignmentScalar_ = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct PerRowLinCombPerRowBiasEltAct + : LinCombPerRowBiasEltAct { + static constexpr int AlignmentScalar = AlignmentScalar_; + static constexpr bool IsPerRowScaleSupported = true; +}; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledLinCombPerRowBiasEltAct + : LinCombPerRowBiasEltAct { + static constexpr bool IsScaleFactorSupported = true; +}; + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementAmax_ = ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / sizeof_bits_v, + int AlignmentBias_ = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledLinCombPerRowBiasEltActAmaxAux + : ScaledLinCombPerRowBiasEltAct { + using ElementAmax = ElementAmax_; + static constexpr bool IsAbsMaxSupported = true; + + using ElementAux = ElementAux_; + using GmemLayoutTagAux = GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxOutSupported = true; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp new file mode 100644 index 0000000000..b2290a40fb --- /dev/null +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -0,0 +1,973 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Fusion callbacks specializations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +using Sm90EVT = Sm90TreeVisitor; + +// D = alpha * acc +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + class ElementOutput, + class ElementCompute, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledAcc, + CtaTileShapeMNK, + EpilogueTile +> : Sm90EVT, + Sm90ScalarBroadcast, + Sm90AccFetch + > { + using Impl = + Sm90EVT, + Sm90ScalarBroadcast, + Sm90AccFetch + >; + using Operation = fusion::ScaledAcc; + + struct Arguments { + // Give a name and flat ordering to the fusion callback args + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + // Conversion to the args expected by the visitor implementation + // to_underlying_arguments will implicitly call this + operator typename Impl::Arguments() const { + return + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C +template< + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinearCombination = + Sm90EVT, // beta * C + (alpha * acc) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast, // alpha + Sm90AccFetch // acc + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + class ElementOutput, + class ElementCompute, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinearCombination, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinearCombination { + + using Impl = Sm90LinearCombination; + using Operation = fusion::LinearCombination; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C) +template< + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombEltAct = + Sm90EVT, // activation(beta * C + (alpha * acc)) + Sm90LinearCombination // beta * C + (alpha * acc) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombEltAct, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombEltAct { + + using Impl = Sm90LinCombEltAct; + using Operation = fusion::LinCombEltAct; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {} // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C + per-row bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerRowBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,_0>, AlignmentBias> // bias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerRowBias, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombPerRowBias< + CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle> { + using Impl = Sm90LinCombPerRowBias< + CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>; + using Operation = fusion::LinCombPerRowBias< + ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle>; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBias const* bias_ptr = nullptr; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C + per-row bias) +template< + class CtaTileShapeMNK, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerRowBiasEltAct = + Sm90EVT, + Sm90LinCombPerRowBias + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBias const* bias_ptr = nullptr; + + operator typename Impl::Arguments() const { + return + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {} // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C + per-row bias) +// Aux = alpha * acc + beta * C + per-row bias) +template< + class CtaTileShapeMNK, + class EpilogueTile, + int Stages, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerRowBiasEltActAux = + Sm90EVT, + Sm90EVT, + Sm90LinCombPerRowBias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementBias, + class ElementScalar, + int AlignmentAux, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerRowBiasEltActAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90LinCombPerRowBiasEltActAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using StrideAux = cutlass::gemm::TagToStrideC_t; + using Impl = + Sm90LinCombPerRowBiasEltActAux< + CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombPerRowBiasEltActAux< + GmemLayoutTagAux, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBias const* bias_ptr = nullptr; + ElementAux* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + return + { // unary op : activation(store(beta * C + (alpha * acc + bias))) + { // unary op : store(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, dAux} // unary args : store + }, // end unary op + {} // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = per-row alpha * acc + per-row beta * C + per-row bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90PerRowLinCombPerRowBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,_0>, AlignmentScalar>, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,_0>, AlignmentScalar>, // alpha + Sm90AccFetch, // acc + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,_0>, AlignmentBias> // bias + > + >; + +// D = activation(per-row alpha * acc + per-row beta * C + per-row bias) +template< + class CtaTileShapeMNK, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90PerRowLinCombPerRowBiasEltAct = + Sm90EVT, + Sm90PerRowLinCombPerRowBias + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementScalar, + int AlignmentBias, + int AlignmentScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::PerRowLinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90PerRowLinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + > { + + using Impl = + Sm90PerRowLinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + using Operation = + fusion::PerRowLinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBias const* bias_ptr = nullptr; + + operator typename Impl::Arguments() const { + return + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {beta_ptr, beta}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {alpha_ptr, alpha}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {} // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// We only apply the scaling factor if output is fp8 +template +struct ScaleOutOp { template using Op = cutlass::first; }; +template <> +struct ScaleOutOp { template using Op = cutlass::multiplies; }; +template <> +struct ScaleOutOp { template using Op = cutlass::multiplies; }; + +template +using amax = cutlass::maximum_absolute_value_reduction; // propogate nans + +}; // end namespace detail + +// D = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, 2>, // scale_c * beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha + Sm90AccFetch, // acc + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,_0>, AlignmentBias> // bias + > + >; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template< + class CtaTileShapeMNK, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBiasEltAct = + Sm90EVT::Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // activation(Z) + // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias + Sm90ScaledLinCombPerRowBias + >, + Sm90ScalarBroadcast // scale_d + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledLinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90ScaledLinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90ScaledLinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle + >; + using Operation = + fusion::ScaledLinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + ElementScalar scale_a = ElementScalar(1); + ElementScalar scale_b = ElementScalar(1); + ElementScalar scale_c = ElementScalar(1); + ElementScalar scale_d = ElementScalar(1); + ElementScalar const* scale_a_ptr = nullptr; + ElementScalar const* scale_b_ptr = nullptr; + ElementScalar const* scale_c_ptr = nullptr; + ElementScalar const* scale_d_ptr = nullptr; + + ElementBias const* bias_ptr = nullptr; + + operator typename Impl::Arguments() const { + return + { // binary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) * scale_d + { // unary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{scale_c, beta}, + {scale_c_ptr, beta_ptr} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{scale_a, scale_b, alpha}, + {scale_a_ptr, scale_b_ptr, alpha_ptr} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {} // unary args : activation + }, // end unary op + {{scale_d}, + {scale_d_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBiasEltActAmaxAux = + Sm90SplitTreeVisitor< + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias + Sm90ScaledLinCombPerRowBias, + // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) + Sm90EVT::Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // amax_d + Sm90EVT, // activation(Z) + Sm90SplitTreeFetch // Z + > + >, + Sm90ScalarBroadcast // scale_d + >, + // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) + Sm90EVT, // store(Aux) + Sm90EVT::Op, ElementCompute, ElementCompute, RoundStyle>, // Z * scale_aux + Sm90EVT, // amax_aux + Sm90SplitTreeFetch // Z + >, + Sm90ScalarBroadcast // scale_aux + > + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementAmax, + class ElementBias, + class ElementScalar, + int AlignmentAux, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementAmax, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90ScaledLinCombPerRowBiasEltActAmaxAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using StrideAux = cutlass::gemm::TagToStrideC_t; + using Impl = + Sm90ScaledLinCombPerRowBiasEltActAmaxAux< + CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementAmax, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + ElementScalar scale_a = ElementScalar(1); + ElementScalar scale_b = ElementScalar(1); + ElementScalar scale_c = ElementScalar(1); + ElementScalar scale_d = ElementScalar(1); + ElementScalar const* scale_a_ptr = nullptr; + ElementScalar const* scale_b_ptr = nullptr; + ElementScalar const* scale_c_ptr = nullptr; + ElementScalar const* scale_d_ptr = nullptr; + + ElementScalar scale_aux = ElementScalar(1); + ElementScalar const* scale_aux_ptr = nullptr; + + ElementBias const* bias_ptr = nullptr; + ElementAmax* amax_D_ptr = nullptr; + ElementAmax* amax_aux_ptr = nullptr; + ElementAux* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + typename Impl::Arguments args; + // always use structured binding to unpack DAG args since it may or may not be a tuple + auto& [Z_args, aux_args, D_args] = args; + + Z_args = + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{scale_c, beta}, + {scale_c_ptr, beta_ptr} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{scale_a, scale_b, alpha}, + {scale_a_ptr, scale_b_ptr, alpha_ptr} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }; // end ternary op + + // Only compute amax_d if D is fp8 + ElementAmax* amax_D_ptr_ = nullptr; + if constexpr (cute::is_same_v || + cute::is_same_v) { + amax_D_ptr_ = amax_D_ptr; + } + D_args = + { // binary op : activation(Z) * scale_d or activation(Z) + { // unary op : reduce(activation(Z)) + { // unary op : activation(Z) + {}, // leaf args : Z + {} // unary args : activation + }, // end unary op + {amax_D_ptr_} // unary args : reduce + }, // end unary op + {{scale_d}, + {scale_d_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + + // Only compute amax_aux if aux is fp8 + ElementAmax* amax_aux_ptr_ = nullptr; + if constexpr (cute::is_same_v || + cute::is_same_v) { + amax_aux_ptr_ = amax_aux_ptr; + } + aux_args = + { // unary op : store(Aux) + { // binary op : Z * scale_d or Z + { // unary op : reduce(Z) + {}, // leaf args : Z + {amax_aux_ptr_} // unary args : reduce + }, // end unary op + {{scale_aux}, + {scale_aux_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies or first + }, // end binary op + {aux_ptr, dAux} // unary args : store + }; // end unary op + + return args; + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp new file mode 100644 index 0000000000..9d3dabd799 --- /dev/null +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -0,0 +1,251 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree compute operations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// N-nary Elementwise Compute Operation +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + template class ComputeFn, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class = void +> +struct Sm90Compute : Sm90VisitorImpl<> { + + using Sm90VisitorImpl<>::Sm90VisitorImpl; + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const&... frg_inputs) { + return transform_apply(cute::make_tuple(frg_inputs...), + [&] (auto&& frg_input) { + using ElementInput = typename cute::remove_cvref_t::Element; + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + return convert_input(frg_input); + }, + [&] (auto&&... cvt_frg_inputs) { + using ComputeOutput = ComputeFn>; + using ConvertOutput = NumericArrayConverter; + ComputeOutput compute_output{}; + ConvertOutput convert_output{}; + + return convert_output(compute_output(cvt_frg_inputs...)); + } + ); + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class TileShapeMNK, + class EpilogueTile, + class TiledCopy, + class SrcTensor + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + SrcTensor const& tCrC) { + return ConsumerStoreCallbacks(); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Performance Optimized Specializations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// beta * C + Z +template < + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class ElementScalar, + class StrideScalar, + int ScalarCount, + template class ScalarReduceFn, + class InputAddOp // Z +> +struct Sm90TreeVisitor< + Sm90Compute, + Sm90ScalarBroadcast, + Sm90SrcFetch, + InputAddOp +> : Sm90VisitorImpl< + Sm90ScalarBroadcast, + Sm90SrcFetch, + InputAddOp, + Sm90Compute + > +{ + using Impl = + Sm90VisitorImpl< + Sm90ScalarBroadcast, + Sm90SrcFetch, + InputAddOp, + Sm90Compute + >; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + auto const& bcast_op = get<0>(Impl::ops); + auto const& added_op = get<2>(Impl::ops); + return not (bcast_op.params_ptr->dScalar == Stride<_0,_0,_0>{} && not is_C_load_needed()) || + added_op.is_producer_load_needed(); + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + auto const& bcast_op = get<0>(Impl::ops); + auto const& added_op = get<2>(Impl::ops); + return bcast_op.scalar != 0 || added_op.is_C_load_needed(); + } + + using Impl::Sm90VisitorImpl; + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks(bool is_C_load_needed, CallbacksImpl&& impl) + : is_C_load_needed(is_C_load_needed), CallbacksImpl(cute::forward(impl)) { } + + bool is_C_load_needed; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_added = get<2>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); + + using ElementZ = typename decltype(frg_added)::Element; + using ConvertZ = NumericArrayConverter; + using ConvertI = NumericArrayConverter; + ConvertZ convert_Z{}; + ConvertI convert_I{}; + + Array frg_I = convert_Z(frg_added); + + if (is_C_load_needed) { + Array frg_scalar = get<0>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); + Array frg_source = get<1>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); + + using ElementX = typename decltype(frg_scalar)::Element; + using ElementY = typename decltype(frg_source)::Element; + using ConvertX = NumericArrayConverter; + using ConvertY = NumericArrayConverter; + using ComputeI = multiply_add>; + ConvertX convert_X{}; + ConvertY convert_Y{}; + ComputeI compute_I{}; + + frg_I = compute_I(convert_X(frg_scalar), convert_Y(frg_source), frg_I); + } + + return convert_I(frg_I); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class TileShapeMNK, + class EpilogueTile, + class TiledCopy, + class SrcTensor + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + SrcTensor const& tCrC) { + return ConsumerStoreCallbacks( + is_C_load_needed(), + Impl::get_consumer_store_callbacks( + problem_shape_mnkl, + tile_shape_mnk, + tile_coord_mnkl, + epi_tile, + tiled_copy, + thread_idx, + tCrC + ) + ); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp new file mode 100644 index 0000000000..28559027a7 --- /dev/null +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -0,0 +1,869 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree load operations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +#include "cute/tensor.hpp" +#include "sm90_visitor_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Elementwise Fetch Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// returns accumulator +struct Sm90AccFetch : Sm90VisitorImpl<> { + + using Sm90VisitorImpl<>::Sm90VisitorImpl; + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + return frg_acc; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class TileShapeMNK, + class EpilogueTile, + class TiledCopy, + class SrcTensor + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + SrcTensor const& tCrC) { + + return ConsumerStoreCallbacks{}; + } +}; + +// Split tree visitor fetches intermediate results from temporary accumulators +using Sm90SplitTreeFetch = Sm90AccFetch; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// returns C +struct Sm90SrcFetch : Sm90VisitorImpl<> { + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return true; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return true; + } + + using Sm90VisitorImpl<>::Sm90VisitorImpl; + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(SrcTensor const& tCrC) + : tCrC(tCrC) {} + + // make this a pointer if we need default ctor for generic tuple of visitors + SrcTensor const& tCrC; // (CPY,CPY_M,CPY_N) + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + return recast>(tCrC)(epi_v); + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class TileShapeMNK, + class EpilogueTile, + class TiledCopy, + class SrcTensor + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + SrcTensor const& tCrC) { + + return ConsumerStoreCallbacks(tCrC); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Elementwise Load Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class EpilogueTile, + class Element, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpS2R, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params +> +struct Sm90AuxLoad { + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + + constexpr static bool is_m_major = epilogue::collective::detail::is_m_major(); + // Find the max contiguous layout usable by TMA (if EpilogueTile is a non-compact tiler) + using SmemShapeTma = decltype(make_shape( + max_common_vector(make_layout(get<0>(EpilogueTile{})),make_layout(get<0>(EpilogueTile{}))), + max_common_vector(make_layout(get<1>(EpilogueTile{})),make_layout(get<1>(EpilogueTile{}))))); + using SmemLayoutTma = decltype(tile_to_shape( + SmemLayoutAtom{}, SmemShapeTma{}, + cute::conditional_t, Step<_1,_2>>{} )); + using SmemLayout = decltype(tile_to_shape( + SmemLayoutTma{}, + make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + + struct SharedStorage { + alignas(128) array_aligned smem_aux; + }; + + struct Arguments { + Element const* ptr_aux = nullptr; + Element null_default = Element(0); + StrideMNL dAux = {}; + }; + + struct Params { + using TMA_Aux = decltype(make_tma_copy( + SM90_TMA_LOAD{}, + make_tensor(static_cast(nullptr), repeat_like(StrideMNL{}, int32_t(0)), StrideMNL{}), + SmemLayoutTma{})); + TMA_Aux tma_load_aux; + Element null_default = Element(0); + bool use_default = false; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + + Tensor tensor_aux = make_tensor(args.ptr_aux, make_layout(make_shape(M,N,L), args.dAux)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(SM90_TMA_LOAD{}, tensor_aux, SmemLayoutTma{}); + + bool use_default = false; + if constexpr (EnableNullptr) { + use_default = args.ptr_aux == nullptr; + } + + return Params{tma_load_aux, args.null_default, use_default}; + } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad() { } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad(Params const& params, SharedStorage& shared_storage) + : params_ptr(¶ms), + smem_aux(shared_storage.smem_aux.data()) { } + + Params const* params_ptr; + Element* smem_aux; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return true; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { + CUTLASS_DEVICE + ProducerLoadCallbacks(GTensor&& bGS_gAux, STensor&& bGS_sAux, Params const* params_ptr) + : bGS_gAux(cute::forward(bGS_gAux)), + bGS_sAux(cute::forward(bGS_sAux)), + params_ptr(params_ptr) {} + + GTensor bGS_gAux; // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + STensor bGS_sAux; // (TMA,TMA_M,TMA_N,PIPE) + Params const* params_ptr; + + CUTLASS_DEVICE void + step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { + if constexpr (EnableNullptr) { + if (params_ptr->use_default) { + return; + } + } + + if (issue_tma_load) { + // Increment the expected transaction bytes of the current stage's mbarrier by the subtile's byte-size + constexpr uint32_t copy_bytes = size(take<0,2>(SmemLayout{})) * sizeof_bytes_v; + cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); + // Issue the TMA load + constexpr uint16_t mcast_mask = 0; + int load_pipe_index = load_iteration % Stages; + copy(params_ptr->tma_load_aux.with(*full_mbarrier_ptr, mcast_mask), + bGS_gAux(_,_,_,epi_m,epi_n), bGS_sAux(_,_,_,load_pipe_index)); + } + } + }; + + template < + class TileShapeMNK + > + CUTLASS_DEVICE auto + get_producer_load_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + int thread_idx) { + + auto [M, N, K, L] = problem_shape_mnkl; + Tensor mAux = params_ptr->tma_load_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor gAux = sm90_tensor_to_cta_tile(mAux, tile_shape_mnk, tile_coord_mnkl); // (CTA_M,CTA_N) + + Tensor gAux_epi = local_tile(gAux, epi_tile, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor sAux_epi = make_tensor(make_smem_ptr(smem_aux), SmemLayout{}); // (EPI_TILE_M,EPI_TILE_N,PIPE) + + ThrCopy thrblk_g2s = params_ptr->tma_load_aux.get_slice(_0{}); + Tensor bGS_gAux = thrblk_g2s.partition_S(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + Tensor bGS_sAux = thrblk_g2s.partition_D(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE) + + return ProducerLoadCallbacks( + cute::move(bGS_gAux), cute::move(bGS_sAux), params_ptr); + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(RTensor&& tC_rAux, TiledS2R tiled_s2r, STensorS2R&& tSR_sAux, Params const* params_ptr) + : tC_rAux(cute::forward(tC_rAux)), + tiled_s2r(tiled_s2r), + tSR_sAux(cute::forward(tSR_sAux)), + params_ptr(params_ptr) { } + + TiledS2R tiled_s2r; + RTensor tC_rAux; // (CPY,CPY_M,CPY_N) + STensorS2R tSR_sAux; // (S2R,S2R_M,S2R_N,PIPE) + Params const* params_ptr; + + CUTLASS_DEVICE void + step_begin(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + if constexpr (EnableNullptr) { + if (params_ptr->use_default) { + fill(tC_rAux, params_ptr->null_default); + return; + } + } + + using RLayoutS2R = decltype(cute::layout(TiledS2R{}.get_slice(0).retile_S(RTensor{}))); + Tensor tSR_rAux = make_tensor(tC_rAux.data(), RLayoutS2R{}); // (S2R,S2R_M,S2R_N) + + int load_pipe_index = load_iteration % Stages; + copy(tiled_s2r, tSR_sAux(_,_,_,load_pipe_index), tSR_rAux); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); // (EPI_V) + + return tC_rAux_frg(epi_v); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class TileShapeMNK, + class TiledCopy, + class SrcTensor + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + SrcTensor const& tCrC) { + + auto [M, N, K, L] = problem_shape_mnkl; + Tensor mAux = params_ptr->tma_load_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mAux, tile_shape_mnk, tile_coord_mnkl, epi_tile, tiled_copy, thread_idx); + Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) + + auto tiled_s2r = conditional_return( + make_tiled_copy_S(Copy_Atom{}, tiled_copy), + make_tiled_copy_D(Copy_Atom{}, tiled_copy) + ); + Tensor sAux_epi = make_tensor(make_smem_ptr(smem_aux), SmemLayout{}); // (EPI_TILE_M,EPI_TILE_N,PIPE) + auto tSR_sAux = tiled_s2r.get_slice(thread_idx).partition_S(sAux_epi); // (S2R,S2R_M,S2R_N,PIPE) + + + return ConsumerStoreCallbacks(cute::move(tC_rAux), tiled_s2r, cute::move(tSR_sAux), params_ptr); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Broadcast Load Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Scalar broadcast +template< + class Element, + class StrideMNL = Stride<_0,_0,_0>, + int BroadcastCount = 1, + template class ReductionFn = multiplies +> +struct Sm90ScalarBroadcast { + static_assert( + (cute::is_same_v>) || // scalar broadcast, e.g. alpha + (cute::is_same_v>)); // batched scalar broadcast, e.g. per-batch alpha + + struct SharedStorage { }; + + struct Arguments { + Element scalars[BroadcastCount] = {}; + Element const* scalar_ptrs[BroadcastCount] = {}; + StrideMNL dScalar = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90ScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90ScalarBroadcast(Params const& params, SharedStorage& shared_storage) + : params_ptr(¶ms) { + // Get the scalar for non-batched broadcast + if constexpr (cute::is_same_v>) { + update_scalar(); + } + } + + Element scalar; + Params const* params_ptr; + + template < + class TileShapeMNK, + class EpilogueTile + > + CUTLASS_DEVICE auto + get_producer_load_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + int thread_idx) { + // Get the scalar for batched broadcast + if constexpr (cute::is_same_v>) { + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + update_scalar(l_coord); + } + + return EmptyProducerLoadCallbacks{}; + } + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(Element scalar) + : scalar(scalar) {} + + Element scalar; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_scalar; + frg_scalar.fill(scalar); + + return frg_scalar; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class TileShapeMNK, + class EpilogueTile, + class TiledCopy, + class SrcTensor + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + SrcTensor const& tCrC) { + + // Get the scalar for batched broadcast + if constexpr (cute::is_same_v>) { + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + update_scalar(l_coord); + } + + return ConsumerStoreCallbacks(scalar); + } + +private: + CUTLASS_DEVICE void + update_scalar(int l_coord = 0) { + int l_offset = l_coord * size<2>(params_ptr->dScalar); + + if (params_ptr->scalar_ptrs[0] != nullptr) { + scalar = params_ptr->scalar_ptrs[0][l_offset]; + } else { + // batch stride is ignored for nullptr fallback + scalar = params_ptr->scalars[0]; + } + + // Do reduction over multiple broadcasts if necessary + ReductionFn reduction_fn; + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < BroadcastCount; ++i) { + if (params_ptr->scalar_ptrs[i] != nullptr) { + scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][l_offset]); + } else { + // batch stride is ignored for nullptr fallback + scalar = reduction_fn(scalar, params_ptr->scalars[i]); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Row vector broadcast +template< + // Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least + // ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params +> +struct Sm90RowBroadcast { + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // row vector broadcast, e.g. per-col alpha/bias + (cute::is_same_v>)); // batched row vector broadcast + + // Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem + struct SharedStorage { + array_aligned(CtaTileShapeMNK{}) * Stages> smem_row; + }; + + struct Arguments { + Element const* ptr_row = nullptr; + Element null_default = Element(0); + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + CUTLASS_HOST_DEVICE + Sm90RowBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90RowBroadcast(Params const& params, SharedStorage& shared_storage) + : params(params), + smem_row(shared_storage.smem_row.data()) { } + + Params params; + Element* smem_row; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return true; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { + CUTLASS_DEVICE + ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params) + : gRow(cute::forward(gRow)), + sRow(cute::forward(sRow)), + params(params) {} + + GTensor gRow; // (CTA_M,CTA_N) + STensor sRow; // (CTA_M,CTA_N,PIPE) + Params const& params; + + CUTLASS_DEVICE void + begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) { + if constexpr (EnableNullptr) { + if (params.ptr_row == nullptr) { + return; + } + } + + if (issue_tma_load) { + // Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size + constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bytes_v; + cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); + // Issue the TMA bulk copy + auto bulk_copy = Copy_Atom{}.with(*full_mbarrier_ptr); + // Filter so we don't issue redundant copies over stride-0 modes + int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; + copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index))); + } + } + }; + + template < + class TileShapeMNK, + class EpilogueTile + > + CUTLASS_DEVICE auto + get_producer_load_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + int thread_idx) { + + auto [M, N, K, L] = problem_shape_mnkl; + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + Tensor gRow = sm90_tensor_to_cta_tile(mRow, tile_shape_mnk, tile_coord_mnkl); // (CTA_M,CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); + + constexpr int EpiTiles = size(shape_div(take<0,2>(tile_shape_mnk), epi_tile)); + return ProducerLoadCallbacks( + cute::move(gRow), cute::move(sRow), params); + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params) + : tCrRow(cute::forward(tCrRow)), + tCsRow(cute::forward(tCsRow)), + params(params) {} + + RTensor tCrRow; // (CPY,CPY_M,CPY_N) + STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) + Params const& params; + + CUTLASS_DEVICE void + step_begin(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + if constexpr (EnableNullptr) { + if (params.ptr_row == nullptr) { + fill(tCrRow, params.null_default); + return; + } + } + + if (epi_m == 0) { // Assumes M-major subtile loop + // Filter so we don't issue redundant copies over stride-0 modes + int bcast_pipe_index = (load_iteration / EpiTiles) % Stages; + copy(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow)); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_row; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_row[i] = tCrRow(epi_v * FragmentSize + i); + } + + return frg_row; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class TileShapeMNK, + class EpilogueTile, + class TiledCopy, + class SrcTensor + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + SrcTensor const& tCrC) { + + Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), + make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); + Tensor tCsRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) + sRow, epi_tile, tiled_copy, thread_idx); + Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N) + + constexpr int EpiTiles = size(shape_div(take<0,2>(tile_shape_mnk), epi_tile)); + return ConsumerStoreCallbacks( + cute::move(tCrRow), cute::move(tCsRow), params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class Element, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params +> +struct Sm90ColBroadcast { + static_assert(Stages == 0, "Column broadcast doesn't support smem usage yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // col vector broadcast, e.g. per-row alpha/bias + (cute::is_same_v>)); // batched col vector broadcast, e.g. batched per-row bias + + // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + struct Arguments { + Element const* ptr_col = nullptr; + Element null_default = Element(0); + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90ColBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90ColBroadcast(Params const& params, SharedStorage& shared_storage) + : params(params) { } + + Params params; + + template < + class TileShapeMNK, + class EpilogueTile + > + CUTLASS_DEVICE auto + get_producer_load_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + int thread_idx) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params) + : tCgCol(cute::forward(tCgCol)), + tCrCol(cute::forward(tCrCol)), + params(params) {} + + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + + CUTLASS_DEVICE void + begin() { + if constexpr (EnableNullptr) { + if (params.ptr_col == nullptr) { + fill(tCrCol, params.null_default); + return; + } + } + + // Filter so we don't issue redundant copies over stride-0 modes + copy(filter(tCgCol), filter(tCrCol)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_col; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); + } + + return frg_col; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class TileShapeMNK, + class EpilogueTile, + class TiledCopy, + class SrcTensor + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + SrcTensor const& tCrC) { + + auto [M, N, K, L] = problem_shape_mnkl; + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, tile_shape_mnk, tile_coord_mnkl, epi_tile, tiled_copy, thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + return ConsumerStoreCallbacks( + cute::move(tCgCol), cute::move(tCrCol), params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Batch matrix broadcast +// Only need to redefine this if we can multicast across cluster L +template < + int Stages, + class EpilogueTile, + class Element, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpS2R, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params +> +using Sm90MatrixBroadcast + = Sm90AuxLoad; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp new file mode 100644 index 0000000000..7da6d09c49 --- /dev/null +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -0,0 +1,865 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree store operations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "sm90_visitor_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Elementwise Store Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class EpilogueTile, + class Element, + FloatRoundStyle RoundStyle, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpR2S, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Noop on nullptr params +> +struct Sm90AuxStore { + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + + constexpr static bool is_m_major = epilogue::collective::detail::is_m_major(); + // Find the max contiguous layout usable by TMA (if EpilogueTile is a non-compact tiler) + using SmemShapeTma = decltype(make_shape( + max_common_vector(make_layout(get<0>(EpilogueTile{})),make_layout(get<0>(EpilogueTile{}))), + max_common_vector(make_layout(get<1>(EpilogueTile{})),make_layout(get<1>(EpilogueTile{}))))); + using SmemLayoutTma = decltype(tile_to_shape( + SmemLayoutAtom{}, SmemShapeTma{}, + cute::conditional_t, Step<_1,_2>>{} )); + using SmemLayout = decltype(tile_to_shape( + SmemLayoutTma{}, + make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + + struct SharedStorage { + alignas(128) array_aligned smem_aux; + }; + + struct Arguments { + Element* ptr_aux = nullptr; + StrideMNL dAux = {}; + }; + + struct Params { + using TMA_Aux = decltype(make_tma_copy( + SM90_TMA_STORE{}, + make_tensor(static_cast(nullptr), repeat_like(StrideMNL{}, int32_t(0)), StrideMNL{}), + SmemLayoutTma{})); + TMA_Aux tma_store_aux; + bool is_nullptr = false; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + + bool is_nullptr = false; + if constexpr (EnableNullptr) { + is_nullptr = args.ptr_aux == nullptr; + } + + typename Params::TMA_Aux tma_store_aux; + if (not is_nullptr) { + Tensor tensor_aux = make_tensor(args.ptr_aux, make_layout(make_shape(M,N,L), args.dAux)); + tma_store_aux = make_tma_copy(SM90_TMA_STORE{}, tensor_aux, SmemLayoutTma{}); + } + + return {tma_store_aux, is_nullptr}; + } + + CUTLASS_HOST_DEVICE + Sm90AuxStore() { } + + CUTLASS_HOST_DEVICE + Sm90AuxStore(Params const& params, SharedStorage& shared_storage) + : params_ptr(¶ms), + smem_aux(shared_storage.smem_aux.data()) { } + + Params const* params_ptr; + Element* smem_aux; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template < + class TileShapeMNK + > + CUTLASS_DEVICE auto + get_producer_load_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + int thread_idx) { + return EmptyProducerLoadCallbacks{}; + } + + template < + class RTensor, + class TiledR2S, + class STensorR2S, + class STensorS2G, + class GTensorS2G + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + RTensor&& tC_rAux, + TiledR2S tiled_r2s, + STensorR2S&& tRS_sAux, + STensorS2G&& bSG_sAux, + GTensorS2G&& bSG_gAux, + Params const* params_ptr) + : tiled_r2s(tiled_r2s), + tC_rAux(cute::forward(tC_rAux)), + tRS_sAux(cute::forward(tRS_sAux)), + bSG_sAux(cute::forward(bSG_sAux)), + bSG_gAux(cute::forward(bSG_gAux)), + params_ptr(params_ptr) {} + + TiledR2S tiled_r2s; + RTensor tC_rAux; // (CPY,CPY_M,CPY_N) + STensorR2S tRS_sAux; // (R2S,R2S_M,R2S_N,PIPE) + STensorS2G bSG_sAux; // (S2G,S2G_M,S2G_N,PIPE) + GTensorS2G bSG_gAux; // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + Params const* params_ptr; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); // (EPI_V) + tC_rAux_frg(epi_v) = convert_input(frg_input); + + return frg_input; + } + + CUTLASS_DEVICE void + step_next(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { + if constexpr (EnableNullptr) { + if (params_ptr->is_nullptr) { + return; + } + } + + using RLayoutR2S = decltype(cute::layout(TiledR2S{}.get_slice(0).retile_S(RTensor{}))); + Tensor tRS_rAux = make_tensor(tC_rAux.data(), RLayoutR2S{}); // (R2S,R2S_M,R2S_N) + + if (issue_smem_store) { + int store_pipe_index = store_iteration % Stages; + copy(tiled_r2s, tRS_rAux, tRS_sAux(_,_,_,store_pipe_index)); + } + } + + CUTLASS_DEVICE void + step_end(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { + if constexpr (EnableNullptr) { + if (params_ptr->is_nullptr) { + return; + } + } + + if (issue_tma_store) { + // Issue the TMA store + int store_pipe_index = store_iteration % Stages; + copy(params_ptr->tma_store_aux, bSG_sAux(_,_,_,store_pipe_index), bSG_gAux(_,_,_,epi_m,epi_n)); + } + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class TileShapeMNK, + class TiledCopy, + class SrcTensor + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + SrcTensor const& tCrC) { + + auto [M, N, K, L] = problem_shape_mnkl; + Tensor mAux = params_ptr->tma_store_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor gAux = sm90_tensor_to_cta_tile(mAux, tile_shape_mnk, tile_coord_mnkl); // (CTA_M,CTA_N) + + Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + gAux, epi_tile, tiled_copy, thread_idx); + Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) + + Tensor sAux_epi = make_tensor(make_smem_ptr(smem_aux), SmemLayout{}); // (EPI_TILE_M,EPI_TILE_N,PIPE) + Tensor gAux_epi = local_tile(gAux, epi_tile, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + auto tiled_r2s = conditional_return( + make_tiled_copy_S(Copy_Atom{}, tiled_copy), + make_tiled_copy_D(Copy_Atom{}, tiled_copy) + ); + auto tRS_sAux = tiled_r2s.get_slice(thread_idx).partition_D(sAux_epi); // (R2S,R2S_M,R2S_N,PIPE) + + ThrCopy thrblk_s2g = params_ptr->tma_store_aux.get_slice(_0{}); + Tensor bSG_sAux = thrblk_s2g.partition_S(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE) + Tensor bSG_gAux = thrblk_s2g.partition_D(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + + return ConsumerStoreCallbacks( + cute::move(tC_rAux), + tiled_r2s, + cute::move(tRS_sAux), + cute::move(bSG_sAux), + cute::move(bSG_gAux), + params_ptr); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Reduction Store Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Scalar reduction +template < + template class RegReduceFn, + template class AtomicReduceFn, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class StrideMNL = Stride<_0,_0,_0>, + bool EnableNullptr = true // Noop on nullptr params +> +struct Sm90ScalarReduction { + static_assert( + (cute::is_same_v>) || // scalar reduction, e.g. tensor max element + (cute::is_same_v>)); // batched scalar reduction, e.g. per-batch max element + + struct SharedStorage { }; + + struct Arguments { + ElementOutput* ptr_scalar = nullptr; + ElementCompute reduction_identity = ElementCompute(0); + StrideMNL dScalar = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90ScalarReduction() { } + + CUTLASS_HOST_DEVICE + Sm90ScalarReduction(Params const& params, SharedStorage& shared_storage) + : params(params) { } + + Params const params; + + template < + class TileShapeMNK, + class EpilogueTile + > + CUTLASS_DEVICE auto + get_producer_load_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + int thread_idx) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + int l_coord, + CTensor&& tCcScalar, + ResidueMN residue_mn, + Params const& params) + : scalar(params.reduction_identity), + l_coord(l_coord), + tCcScalar(cute::forward(tCcScalar)), + residue_mn(residue_mn), + params(params) {} + + ElementCompute scalar; + int l_coord; + CTensor tCcScalar; + ResidueMN residue_mn; + Params params; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + if constexpr (EnableNullptr) { + if (params.ptr_scalar == nullptr) { + return frg_input; + } + } + + using ConvertInput = NumericArrayConverter; + using ReduceInput = RegReduceFn; + ConvertInput convert_input{}; + ReduceInput reduce_input{}; + + Array frg_I = convert_input(frg_input); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + if (elem_less(tCcScalar(epi_v * FragmentSize + i), residue_mn)) { + scalar = reduce_input(scalar, frg_I[i]); + } + } + + return frg_input; + } + + CUTLASS_DEVICE void + end() { + if constexpr (EnableNullptr) { + if (params.ptr_scalar == nullptr) { + return; + } + } + + using ConvertI = NumericConverter; + using ReduceInput = AtomicReduceFn; + + ConvertI convert_I{}; + ReduceInput reduce_input{}; + + ElementOutput* ptr_scalar = params.ptr_scalar + l_coord * get<2>(params.dScalar); + reduce_input(ptr_scalar, convert_I(scalar)); + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class TileShapeMNK, + class EpilogueTile, + class TiledCopy, + class SrcTensor + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + SrcTensor const& tCrC) { + + int l_coord = static_cast(get<3>(tile_coord_mnkl)); + + // Compute tile residues and coordinate tensors for predication + auto [M, N, K, L] = problem_shape_mnkl; + auto [m, n, k, l] = tile_coord_mnkl; + auto residue_mn = make_coord( + M - static_cast(m) * size<0>(tile_shape_mnk), + N - static_cast(n) * size<1>(tile_shape_mnk) + ); + Tensor cScalar = make_identity_tensor(take<0,2>(tile_shape_mnk)); + Tensor tCcScalar = sm90_partition_for_epilogue(cScalar, epi_tile, tiled_copy, thread_idx); + + return ConsumerStoreCallbacks(l_coord, cute::move(tCcScalar), residue_mn, params); + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Row vector reduction +template < + template class RegReduceFn, + template class AtomicReduceFn, + int Stages, + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Noop on nullptr params +> +struct Sm90RowReduction { + static_assert(Stages == 0, "Smem usage not supported yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // row vector reduction, e.g. per-col sum over all batches + (cute::is_same_v>)); // batched row vector reduction, e.g. per-col sum per batch + + struct SharedStorage { }; + + struct Arguments { + ElementOutput* ptr_row = nullptr; + ElementCompute reduction_identity = 0; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90RowReduction() { } + + CUTLASS_HOST_DEVICE + Sm90RowReduction(Params const& params, SharedStorage& shared_storage) + : params(params) { } + + Params params; + + template < + class TileShapeMNK, + class EpilogueTile + > + CUTLASS_DEVICE auto + get_producer_load_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + int thread_idx) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + RTensor&& tCrRow, + GTensor&& tCgRow, + CTensor&& tCcRow, + ResidueMN residue_mn, + Params const& params) + : tCrRow(cute::forward(tCrRow)), + tCgRow(cute::forward(tCgRow)), + tCcRow(cute::forward(tCcRow)), + residue_mn(residue_mn), + params(params) {} + + // gmem store after every column of subtiles, assuming M-major loop + // needed to reduce reg pressure, otherwise each thread stores up to a full row in RF + // since row-elements aren't evenly distributed amongst threads + RTensor tCrRow; // (CPY,CPY_M,CPY_N,EPI_M) + GTensor tCgRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ResidueMN residue_mn; + Params const& params; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + + if constexpr (EnableNullptr) { + if (params.ptr_row == nullptr) { + return frg_input; + } + } + + using ConvertInput = NumericArrayConverter; + using ReduceInput = RegReduceFn; + ConvertInput convert_input{}; + ReduceInput reduce_input{}; + + Array frg_I = convert_input(frg_input); + Tensor tCrRow_mn = tCrRow(_,_,_,epi_m); + Tensor tCcRow_mn = tCcRow(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + if (elem_less(tCcRow_mn(i), residue_mn)) { + ElementCompute& tCrRow_vmn = tCrRow_mn(epi_v * FragmentSize + i); + tCrRow_vmn = reduce_input(tCrRow_vmn, frg_I[i]); + } + } + + return frg_input; + } + + CUTLASS_DEVICE void + step_end(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { + if constexpr (EnableNullptr) { + if (params.ptr_row == nullptr) { + return; + } + } + + if (epi_m == size<3>(tCrRow)-1) { // assumes M-major subtile loop + using ConvertI = NumericConverter; + using ReduceInput = AtomicReduceFn; + + ConvertI convert_I{}; + ReduceInput reduce_input{}; + + // Filter so we don't issue redunant copies over stride-0 modes + Tensor tCrRow_flt = filter_zeros(tCrRow(_,_,_,epi_m)); + Tensor tCgRow_flt = filter_zeros(tCgRow(_,_,_,epi_m,epi_n)); + Tensor tCcRow_mn = tCcRow(_,_,_,epi_m,epi_n); + Tensor tCcRow_flt = make_tensor(tCcRow_mn.data(), make_layout(tCgRow_flt.shape(), tCcRow_mn.stride())); + + + auto [residue_m, residue_n] = residue_mn; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrRow_flt); ++i) { + // partially OOB in M must still issue gmem reduction, so only consider residue_n + // in case last epi tile in column is fully OOB in M and CTA tile is partially OOB in M + if (residue_n > get<1>(tCcRow_flt(i)) && + // fully OOB in M does not need to issue gmem reduction, skip + residue_m > 0) { + reduce_input(&tCgRow_flt(i), convert_I(tCrRow_flt(i))); + } + } + + // Reset the registers to the reduction identity + fill(tCrRow, params.reduction_identity); + } + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class TileShapeMNK, + class EpilogueTile, + class TiledCopy, + class SrcTensor + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + SrcTensor const& tCrC) { + + auto [M, N, K, L] = problem_shape_mnkl; + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + Tensor tCgRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mRow, tile_shape_mnk, tile_coord_mnkl, epi_tile, tiled_copy, thread_idx); + Tensor tCrRow = make_tensor_like(tCgRow(_,_,_,_,_0{})); // (CPY,CPY_M,CPY_N,EPI_M) + fill(tCrRow, params.reduction_identity); + + // Compute tile residues and coordinate tensors for predication + auto [m, n, k, l] = tile_coord_mnkl; + auto residue_mn = make_coord( + M - static_cast(m) * size<0>(tile_shape_mnk), + N - static_cast(n) * size<1>(tile_shape_mnk) + ); + Tensor cRow = make_identity_tensor(take<0,2>(tile_shape_mnk)); + Tensor tCcRow = sm90_partition_for_epilogue(cRow, epi_tile, tiled_copy, thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tCrRow), cute::move(tCgRow), cute::move(tCcRow), residue_mn, params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Col vector reduction +template < + template class RegReduceFn, + template class AtomicReduceFn, + int Stages, + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Noop on nullptr params +> +struct Sm90ColReduction { + static_assert(Stages == 0, "Smem usage not supported yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert( + (cute::is_same_v>) || // col vector reduction, e.g. per-row sum over all batches + (cute::is_same_v>)); // batched col vector reduction, e.g. per-row sum per batch + + struct SharedStorage { }; + + struct Arguments { + ElementOutput* ptr_col = nullptr; + ElementCompute reduction_identity = 0; + StrideMNL dCol = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90ColReduction() { } + + CUTLASS_HOST_DEVICE + Sm90ColReduction(Params const& params, SharedStorage& shared_storage) + : params(params) { } + + Params params; + + template < + class TileShapeMNK, + class EpilogueTile + > + CUTLASS_DEVICE auto + get_producer_load_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + int thread_idx) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + RTensor&& tCrCol, + GTensor&& tCgCol, + CTensor&& tCcCol, + ResidueMN residue_mn, + Params const& params) + : tCrCol(cute::forward(tCrCol)), + tCgCol(cute::forward(tCgCol)), + tCcCol(cute::forward(tCcCol)), + residue_mn(residue_mn), + params(params) {} + + RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ResidueMN residue_mn; + Params const& params; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + + if constexpr (EnableNullptr) { + if (params.ptr_col == nullptr) { + return frg_input; + } + } + + using ConvertInput = NumericArrayConverter; + using ReduceInput = RegReduceFn; + ConvertInput convert_input{}; + ReduceInput reduce_input{}; + + Array frg_I = convert_input(frg_input); + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + if (elem_less(tCcCol(i), residue_mn)) { + ElementCompute& tCrCol_vmn = tCrCol_mn(epi_v * FragmentSize + i); + tCrCol_vmn = reduce_input(tCrCol_vmn, frg_I[i]); + } + } + + return frg_input; + } + + CUTLASS_DEVICE void + end() { + if constexpr (EnableNullptr) { + if (params.ptr_col == nullptr) { + return; + } + } + + using ConvertI = NumericConverter; + using ReduceInput = AtomicReduceFn; + + ConvertI convert_I{}; + ReduceInput reduce_input{}; + + // Filter so we don't issue redunant copies over stride-0 modes + Tensor tCrCol_flt = filter_zeros(tCrCol); + Tensor tCgCol_flt = filter_zeros(tCgCol); + Tensor tCcCol_flt = make_tensor(tCcCol.data(), make_layout(tCgCol_flt.shape(), tCcCol.stride())); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrCol_flt); ++i) { + if (elem_less(tCcCol_flt(i), residue_mn)) { + reduce_input(&tCgCol_flt(i), convert_I(tCrCol_flt(i))); + } + } + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class TileShapeMNK, + class EpilogueTile, + class TiledCopy, + class SrcTensor + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + SrcTensor const& tCrC) { + + auto [M, N, K, L] = problem_shape_mnkl; + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, tile_shape_mnk, tile_coord_mnkl, epi_tile, tiled_copy, thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + fill(tCrCol, params.reduction_identity); + + // Compute tile residues and coordinate tensors for predication + auto [m, n, k, l] = tile_coord_mnkl; + auto residue_mn = make_coord( + M - static_cast(m) * size<0>(tile_shape_mnk), + N - static_cast(n) * size<1>(tile_shape_mnk) + ); + Tensor cCol = make_identity_tensor(take<0,2>(tile_shape_mnk)); + Tensor tCcCol = sm90_partition_for_epilogue(cCol, epi_tile, tiled_copy, thread_idx); + + return ConsumerStoreCallbacks(cute::move(tCrCol), cute::move(tCgCol), cute::move(tCcCol), residue_mn, params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Batch matrix reduction +template < + int Stages, + class EpilogueTile, + class Element, + class StrideMNL, + class CopyOpR2S, + class SmemLayoutAtom, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Noop on nullptr params +> +struct Sm90MatrixReduction; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp new file mode 100644 index 0000000000..7750701e18 --- /dev/null +++ b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp @@ -0,0 +1,827 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree operation base implementation to enable composable fusions + for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using cute::tuple; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Convenience aliases +using ProblemShapeMNKL = tuple; +using TileCoordMNKL = tuple; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partitioning Helpers +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class Engine, class LayoutMNL, + class TileShapeMNK +> +CUTLASS_HOST_DEVICE +constexpr auto +sm90_tensor_to_cta_tile( + Tensor mT, // (M,N,L) + TileShapeMNK tile_shape_mnk, // (CTA_M,CTA_N,CTA_K) + TileCoordMNKL tile_coord_mnkl) { + using _X = Underscore; + + auto [m, n, k, l] = tile_coord_mnkl; + Tensor mT_mnl = local_tile(mT, tile_shape_mnk, make_coord(_,_,_), Step<_1,_1,_X>{}); // (CTA_M,CTA_N) + + return mT_mnl(_,_,m,n,l); +} + +template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class CtaTileMN, + class EpilogueTile, + class TiledCopy +> +CUTLASS_HOST_DEVICE +constexpr auto +sm90_partition_for_epilogue( + CtaTileMN cT, // (CTA_M,CTA_N,...) + EpilogueTile epi_tile, // (EPI_TILE_M,EPI_TILE_N) + TiledCopy tiled_copy, + int thread_idx) { + ThrCopy thread_copy = tiled_copy.get_thread_slice(thread_idx); + Tensor cT_epi = local_tile(cT, epi_tile, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N,...) + if constexpr (ReferenceSrc) { + return thread_copy.partition_S(cT_epi); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,...) + } + else { + return thread_copy.partition_D(cT_epi); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,...) + } +} + +template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class Engine, class LayoutMNL, + class TileShapeMNK, + class EpilogueTile, + class TiledCopy +> +CUTLASS_HOST_DEVICE +constexpr auto +sm90_partition_for_epilogue( + Tensor mT, // (M,N,L) + TileShapeMNK tile_shape_mnk, // (CTA_M,CTA_N,CTA_K) + TileCoordMNKL tile_coord_mnkl, // (m,n,k,l) + EpilogueTile epi_tile, // (EPI_TILE_M,EPI_TILE_N) + TiledCopy tiled_copy, + int thread_idx) { + Tensor cT = sm90_tensor_to_cta_tile(mT, tile_shape_mnk, tile_coord_mnkl); // (CTA_M,CTA_N) + Tensor tCcT = + sm90_partition_for_epilogue(cT, epi_tile, tiled_copy, thread_idx); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + return tCcT; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Visitor Implementation +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Sm90VisitorImplBase { + // Shared memory allocation + using SharedStorage = tuple; + // Host side fusion arguments + using Arguments = tuple; + // Device side fusion params (Kernel-entry API) + using Params = tuple; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return transform_apply(tuple{}, args, + [&] (auto&& op, auto const& op_args) { + using Op = cute::remove_cvref_t; + return Op::to_underlying_arguments(problem_shape, op_args, workspace); + }, + [] (auto&&... op_params) { return cute::make_tuple(op_params...); } + ); + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage& shared_storage) + : ops(transform_apply(tuple{}, params, shared_storage, + [] (auto&& op, auto const& op_params, auto&& op_storage) { + using Op = cute::remove_cvref_t; + return Op(op_params, op_storage); + }, + [] (auto&&... ops) { return cute::make_tuple(ops...); } + )) {} + + // Ops can store kernel persistent variables (e.g. descriptors, scalars, wave counters) + tuple ops; +}; + + +template +struct Sm90VisitorImpl : Sm90VisitorImplBase { + + using Sm90VisitorImplBase::Sm90VisitorImplBase; + using Sm90VisitorImplBase::ops; + + // + // Queries for kernel runtime + // + + // Is a specialized warp for producer TMA loads needed + // e.g. Aux tensor loads, broadcasts using TMA bulk copy + // This condition cannot change between work tiles because it is used + // to determine whether the load warp should exit early or not + // e.g. for batched beta this must always be true regardless of current batch idx + CUTLASS_DEVICE bool + is_producer_load_needed() const { + bool needed = false; + for_each(ops, + [&] (auto const& op) { + needed |= op.is_producer_load_needed(); + } + ); + return needed; + } + + // Is a producer TMA load specifically for C needed + // If this is true then is_producer_load_needed must also be true + // This condition can change between work tiles because it is only used + // to determine whether the TMA and smem loads for C of a given tile should happen + // e.g. for batched beta this can be false depending on current batch idx + CUTLASS_DEVICE bool + is_C_load_needed() const { + bool needed = false; + for_each(ops, + [&] (auto const& op) { + needed |= op.is_C_load_needed(); + } + ); + return needed; + } + + // + // Producer load callbacks, called by the epilogue load warp. + // Operations usually only define this if TMA load is needed. Most operations will reuse this empy implementation + // Load callbacks are responsible for issuing corresponding mbarrier expect-tx ops for any TMA loads issued, but + // are not responsible for issuing the producer_commit barrier arrival, which is issued by the collective instead + // If this is non-empty, is_producer_load_needed must be true. + // + template + struct ProducerLoadCallbacks { + // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables + CallbacksTuple callbacks_tuple; + + // Before entry of the subtile load loop. Bulk copies usually performed here. + // Upon entry the producer_acquire of the first subtile lock has completed. + // full_mbarrier_ptr is the corresponding barrier for the subsequent producer_commit arrival + CUTLASS_DEVICE void + begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.begin(full_mbarrier_ptr, load_iteration, issue_tma_load); + } + ); + } + + // Entry of the subtile load loop. Aux loads usually performed here + // Upon entry the producer acquire of the current subtile lock has completed. + // Upon exit all TMA loads for this subtile must have been issued, with corresponding expect-tx operations + CUTLASS_DEVICE void + step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.step(full_mbarrier_ptr, epi_m, epi_n, load_iteration, issue_tma_load); + } + ); + } + + // Exit of the subtile load loop. + CUTLASS_DEVICE void + end() { + for_each(callbacks_tuple, + [] (auto& callbacks) { + callbacks.end(); + } + ); + } + }; + + // Producer load callbacks factory + // All operations must redefine this, but most can just dispatch to the base impl + template < + class TileShapeMNK, + class EpilogueTile + > + CUTLASS_DEVICE auto + get_producer_load_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + int thread_idx) { + return transform_apply(ops, + [&] (auto& op) { + return op.get_producer_load_callbacks( + problem_shape_mnkl, + tile_shape_mnk, + tile_coord_mnkl, + epi_tile, + thread_idx + ); + }, + [] (auto&&... callbacks) { + auto callbacks_tuple = cute::make_tuple(callbacks...); + return ProducerLoadCallbacks{callbacks_tuple}; + } + ); + } + + // + // Consumer store callbacks, called by the epilogue store warps. + // All operations must redefine this, with optional inheritance from this empty implementation. + // + template + struct ConsumerStoreCallbacks { + // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables + CallbacksTuple callbacks_tuple; + + // Before entry of subtile store loop. Gmem broadcasts usually performed here. + CUTLASS_DEVICE void + begin() { + for_each(callbacks_tuple, + [] (auto& callbacks) { + callbacks.begin(); + } + ); + } + + // Start of subtile store iteration. Smem broadcasts usually performed here. + // Upon entry, all producer loads for this subtile are completed and visible. + CUTLASS_DEVICE void + step_begin(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.step_begin(epi_m, epi_n, load_iteration, is_producer_load_needed); + } + ); + } + + // Perform the fused elementwise computation + template + CUTLASS_DEVICE auto // returns an Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const&... frg_inputs) // depends on the N-naryness of the op + = delete; // Must be implemented for each operation + + // After D smem store, before smem async fence. Smem reductions usually performed here. + // Upon exit, all smem stores for TMA must have been issued + CUTLASS_DEVICE void + step_next(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.step_next(epi_m, epi_n, store_iteration, issue_smem_store); + } + ); + } + + // End of subtile iteration, before TMA store commit. Aux stores usually performed here + // Upon exit, all TMA stores for this subtile must have been issued + CUTLASS_DEVICE void + step_end(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.step_end(epi_m, epi_n, store_iteration, issue_tma_store); + } + ); + } + + // Exit of subtile store loop. Gmem reductions usually performed here. + CUTLASS_DEVICE void + end() { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.end(); + } + ); + } + }; + + // Consumer store callbacks factory + // All operations must redefine this + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class TileShapeMNK, + class EpilogueTile, + class TiledCopy, + class SrcTensor + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + SrcTensor const& tCrC) { + return transform_apply(ops, + [&] (auto& op) { + return op.template get_consumer_store_callbacks( + problem_shape_mnkl, + tile_shape_mnk, + tile_coord_mnkl, + epi_tile, + tiled_copy, + thread_idx, + tCrC + ); + }, + [] (auto&&... callbacks) { + auto callbacks_tuple = cute::make_tuple(callbacks...); + return ConsumerStoreCallbacks{callbacks_tuple}; + } + ); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Convenience aliases +using EmptyProducerLoadCallbacks = Sm90VisitorImpl<>::ProducerLoadCallbacks>; +using EmptyConsumerStoreCallbacks = Sm90VisitorImpl<>::ConsumerStoreCallbacks>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +using namespace detail; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Tree visitor +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Sm90TreeVisitor : Sm90VisitorImpl { + + using Sm90VisitorImpl::Sm90VisitorImpl; + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks(CallbacksImpl&& impl) + : CallbacksImpl(cute::forward(impl)) {} + + using CallbacksImpl::callbacks_tuple; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + constexpr int Rm1 = sizeof...(ChildOps); + return cute::detail::tapply(callbacks_tuple, + [&] (auto& child_callbacks) { + return child_callbacks.visit(frg_acc, epi_v, epi_m, epi_n); // child ops must be nullary (e.g. loads, trees) + }, + [&] (auto&&... frg_inputs) { + return get(callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n, frg_inputs...); + }, + make_seq{} // restrict the transform to R-1 child ops, apply is for node op + ); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class TileShapeMNK, + class EpilogueTile, + class TiledCopy, + class SrcTensor + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + SrcTensor const& tCrC) { + return ConsumerStoreCallbacks( + Sm90VisitorImpl:: + get_consumer_store_callbacks( + problem_shape_mnkl, + tile_shape_mnk, + tile_coord_mnkl, + epi_tile, + tiled_copy, + thread_idx, + tCrC + ) + ); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// DAG visitors +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Most DAG fusions can be represented as a set of output trees with a common input tree +// The common input is first evaluated, then the result is passed as the acc fragment to the output trees +template +struct Sm90SplitTreeVisitor : Sm90VisitorImpl { + + using Sm90VisitorImpl::Sm90VisitorImpl; + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks(CallbacksImpl&& impl) + : CallbacksImpl(cute::forward(impl)) {} + + using CallbacksImpl::callbacks_tuple; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_input = get<0>(callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); + + constexpr int Rm2 = sizeof...(AuxOutTrees); + cute::detail::for_sequence(make_seq{}, // restrict the sequence to aux out trees + [&] (auto&& _I) { + constexpr int i = remove_cvref_t::value; + get(callbacks_tuple).visit(frg_input, epi_v, epi_m, epi_n); + } + ); + + return get(callbacks_tuple).visit(frg_input, epi_v, epi_m, epi_n); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class TileShapeMNK, + class EpilogueTile, + class TiledCopy, + class SrcTensor + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + SrcTensor const& tCrC) { + return ConsumerStoreCallbacks( + Sm90VisitorImpl:: + get_consumer_store_callbacks( + problem_shape_mnkl, + tile_shape_mnk, + tile_coord_mnkl, + epi_tile, + tiled_copy, + thread_idx, + tCrC + ) + ); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // deducing the output type for all the nodes is tricky so we just convert them all to a common type + // if multiple compute types are needed then split into multiple subgraphs grouped by type + class ElementCompute, + class EdgeTuple, // tuple of int_sequence, each sequence is the children indices (indexed by topological order) for each node + class... Ops // in topological order, last op is the output. EdgeTuple must match this order +> +struct Sm90TopologicalVisitor : Sm90VisitorImpl { + static_assert(is_static_v); + static_assert(rank(EdgeTuple{}) == sizeof...(Ops)); + static_assert(sizeof...(Ops) > 1); + + using Sm90VisitorImpl::Sm90VisitorImpl; + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks(CallbacksImpl&& impl) + : CallbacksImpl(cute::forward(impl)) {} + + using CallbacksImpl::callbacks_tuple; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + constexpr int Rm1 = sizeof...(Ops) - 1; + auto frg_compute_tuple = cute::repeat(Array{}); + + return cute::detail::tapply(EdgeTuple{}, callbacks_tuple, frg_compute_tuple, + // Visit the first R-1 ops in topological order + [&] (auto&& edge_seq, auto& callbacks, auto& frg_compute) { + frg_compute = cute::detail::apply(frg_compute_tuple, + // Compute the current op with children inputs + [&] (auto const&... frg_inputs) { + auto frg_output = callbacks.visit(frg_acc, epi_v, epi_m, epi_n, frg_inputs...); + using ElementOutput = typename decltype(frg_output)::Element; + using ConvertOutput = NumericArrayConverter; + ConvertOutput convert_output{}; + + return convert_output(frg_output); + }, + // Get inputs in the sequence given by the children indices of the current op + edge_seq + ); + return frg_compute; // unused + }, + // Visit the last op + [&] (auto const&...) { + return cute::detail::apply(frg_compute_tuple, + // Compute the last op with children inputs + [&] (auto const&... frg_inputs) { + return get(callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n, frg_inputs...); + }, + // Get inputs in the sequence given by the children indices of the last op + get(EdgeTuple{}) + ); + }, + // Transform to visit R-1 ops, apply to visit last op + make_seq{} + ); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class TileShapeMNK, + class EpilogueTile, + class TiledCopy, + class SrcTensor + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + SrcTensor const& tCrC) { + return ConsumerStoreCallbacks( + Sm90VisitorImpl:: + get_consumer_store_callbacks( + problem_shape_mnkl, + tile_shape_mnk, + tile_coord_mnkl, + epi_tile, + tiled_copy, + thread_idx, + tCrC + ) + ); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Base specializations so we can have standard layout params and simple aggregate initializers +namespace detail { + +template +struct Sm90VisitorImplBase { + + struct SharedStorage { + typename Op0::SharedStorage op_0; + }; + + struct Arguments { + typename Op0::Arguments op_0; + }; + + struct Params { + typename Op0::Params op_0; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return Params{ + Op0::to_underlying_arguments(problem_shape, args.op_0, workspace) + }; + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage& shared_storage) + : ops({ + Op0(params.op_0, shared_storage.op_0) + }) {} + + tuple ops; +}; + +template +struct Sm90VisitorImplBase { + + struct SharedStorage { + typename Op0::SharedStorage op_0; + typename Op1::SharedStorage op_1; + }; + + struct Arguments { + typename Op0::Arguments op_0; + typename Op1::Arguments op_1; + }; + + struct Params { + typename Op0::Params op_0; + typename Op1::Params op_1; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return Params{ + Op0::to_underlying_arguments(problem_shape, args.op_0, workspace), + Op1::to_underlying_arguments(problem_shape, args.op_1, workspace) + }; + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage& shared_storage) + : ops({ + Op0(params.op_0, shared_storage.op_0), + Op1(params.op_1, shared_storage.op_1) + }) {} + + tuple ops; +}; + +template +struct Sm90VisitorImplBase { + + struct SharedStorage { + typename Op0::SharedStorage op_0; + typename Op1::SharedStorage op_1; + typename Op2::SharedStorage op_2; + }; + + struct Arguments { + typename Op0::Arguments op_0; + typename Op1::Arguments op_1; + typename Op2::Arguments op_2; + }; + + struct Params { + typename Op0::Params op_0; + typename Op1::Params op_1; + typename Op2::Params op_2; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return Params{ + Op0::to_underlying_arguments(problem_shape, args.op_0, workspace), + Op1::to_underlying_arguments(problem_shape, args.op_1, workspace), + Op2::to_underlying_arguments(problem_shape, args.op_2, workspace) + }; + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage& shared_storage) + : ops({ + Op0(params.op_0, shared_storage.op_0), + Op1(params.op_1, shared_storage.op_1), + Op2(params.op_2, shared_storage.op_2) + }) {} + + tuple ops; +}; + +template +struct Sm90VisitorImplBase { + + struct SharedStorage { + typename Op0::SharedStorage op_0; + typename Op1::SharedStorage op_1; + typename Op2::SharedStorage op_2; + typename Op3::SharedStorage op_3; + }; + + struct Arguments { + typename Op0::Arguments op_0; + typename Op1::Arguments op_1; + typename Op2::Arguments op_2; + typename Op3::Arguments op_3; + }; + + struct Params { + typename Op0::Params op_0; + typename Op1::Params op_1; + typename Op2::Params op_2; + typename Op3::Params op_3; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return Params{ + Op0::to_underlying_arguments(problem_shape, args.op_0, workspace), + Op1::to_underlying_arguments(problem_shape, args.op_1, workspace), + Op2::to_underlying_arguments(problem_shape, args.op_2, workspace), + Op3::to_underlying_arguments(problem_shape, args.op_3, workspace) + }; + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage& shared_storage) + : ops({ + Op0(params.op_0, shared_storage.op_0), + Op1(params.op_1, shared_storage.op_1), + Op2(params.op_2, shared_storage.op_2), + Op3(params.op_3, shared_storage.op_3) + }) {} + + tuple ops; +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 79c6072c15..526d46b569 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -144,19 +144,22 @@ struct ReLu { } }; +template +using ReLU = ReLu; + template struct ReLu> { static const bool kIsHeavy=false; CUTLASS_HOST_DEVICE Array operator()(T const & threshold, Array const &frag) const { - maximum > mx; + maximum> mx; return mx(frag, threshold); } CUTLASS_HOST_DEVICE Array operator()(Array const &frag) const { - maximum > mx; + maximum> mx; return mx(frag, T(0)); } @@ -531,7 +534,7 @@ struct GELU { CUTLASS_HOST_DEVICE float operator()(float const &scalar) const { return cutlass::constants::half() * scalar * - (cutlass::constants::one() + erff( scalar * cutlass::constants::half_root_two() )); + (cutlass::constants::one() + erff(scalar * cutlass::constants::half_root_two() )); } using Params = LinearCombinationGenericParams; diff --git a/include/cutlass/epilogue/thread/linear_combination.h b/include/cutlass/epilogue/thread/linear_combination.h index 918f3301e8..3880204a13 100644 --- a/include/cutlass/epilogue/thread/linear_combination.h +++ b/include/cutlass/epilogue/thread/linear_combination.h @@ -72,6 +72,7 @@ class LinearCombination { using ElementSource = ElementSource_; using ElementAccumulator = ElementAccumulator_; using ElementCompute = ElementCompute_; + using ElementScalar = ElementCompute; using ElementC = ElementSource_; using ElementD = ElementOutput_; @@ -164,49 +165,6 @@ class LinearCombination { } } - /// Computes intermediate: X = beta * source - CUTLASS_HOST_DEVICE - FragmentCompute compute_intermediate( - FragmentSource const &source) const { - - // Convert source to internal compute numeric type - NumericArrayConverter source_converter; - FragmentCompute converted_source = source_converter(source); - - if (Scale == ScaleType::NoBetaScaling) { - return converted_source; - } - else { - multiplies mul_source; - return mul_source(beta_, converted_source); - } - } - - /// Computes linear scaling with intermediate: D = alpha * accumulator + X - CUTLASS_HOST_DEVICE - FragmentOutput with_intermediate( - FragmentAccumulator const& accumulator, - FragmentCompute const& intermediate) const { - - // Convert accumulator to internal compute numeric type - NumericArrayConverter accumulator_converter; - - // Convert to destination numeric type - NumericArrayConverter destination_converter; - - FragmentCompute converted_accumulator = accumulator_converter(accumulator); - - if (Scale == ScaleType::Nothing) { - return destination_converter(converted_accumulator); - } else { - // Perform binary operations - multiply_add mul_add_accumulator; - FragmentCompute computed_output = mul_add_accumulator(alpha_, converted_accumulator, intermediate); - - return destination_converter(computed_output); - } - } - /// Computes linear scaling with source: D = alpha * accumulator + beta * source CUTLASS_HOST_DEVICE FragmentOutput operator()( @@ -321,6 +279,196 @@ class LinearCombination { } }; +/// Applies a linear combination operator to an array of elements. +/// +/// D = vector_alpha * accumulator + (optional) vector_beta/scalar_beta * source +/// +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation. + typename ElementAccumulator_, ///< Accumulator data type + typename ElementCompute_, ///< Data type used to compute linear combination + FloatRoundStyle Round, + typename ElementSource_ +> +class LinearCombination { +public: + + using ElementOutput = ElementOutput_; + using ElementSource = ElementSource_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementC = ElementSource_; + using ElementD = ElementOutput_; + + static int const kCount = Count; + static const ScaleType::Kind kScale = ScaleType::PerChannelScaling; + static constexpr bool IsPerChannelScalingSupported = true; + + using FragmentOutput = Array; + using FragmentSource = Array; + using FragmentAccumulator = Array; + using FragmentCompute = Array; + + static FloatRoundStyle const kRound = Round; + + /// Host-constructable parameters structure + struct Params + { + ElementCompute const *alpha_ptr; ///< pointer to accumulator vector + ElementCompute const *beta_ptr; ///< pointer to source vector + ElementCompute beta; ///< scales source tensor + + CUTLASS_HOST_DEVICE + Params(): + alpha_ptr(nullptr), + beta_ptr(nullptr), + beta(ElementCompute(0)) { } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute const *alpha_ptr, + ElementCompute const *beta_ptr + ): + alpha_ptr(alpha_ptr), beta_ptr(beta_ptr), beta(ElementCompute(0)) { } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute const *alpha_ptr + ): + alpha_ptr(alpha_ptr), beta_ptr(nullptr), beta(ElementCompute(0)) { } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute const *alpha_ptr, + ElementCompute beta + ): + alpha_ptr(alpha_ptr), beta_ptr(nullptr), beta(beta) { } + + }; + +private: + + // + // Data members + // + + ElementCompute const* beta_ptr_ = nullptr; + ElementCompute beta_ = 0; + +public: + + /// Constructs the function object + CUTLASS_HOST_DEVICE + LinearCombination(Params const& params) { + if (params.beta_ptr) { + beta_ptr_ = params.beta_ptr; + } + else { + beta_ = params.beta; + } + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return beta_ptr_ != nullptr || beta_ != ElementCompute(0); + } + + CUTLASS_HOST_DEVICE + bool is_beta_vector() const { + return beta_ptr_ != nullptr; + } + + /// Computes linear scaling with source: D = vector_alpha * accumulator + vector_beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const& accumulator, + FragmentSource const& source, + FragmentCompute const& valpha, + FragmentCompute const& vbeta) const { + // Convert source to internal compute numeric type + NumericArrayConverter source_converter; + NumericArrayConverter accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + FragmentCompute converted_source = source_converter(source); + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + intermediate = mul_add_source(vbeta, converted_source); // X = vector_beta * C + uniform + + intermediate = mul_add_accumulator(valpha, converted_accumulator, intermediate); // D = vector_alpha * Accum + X + + return destination_converter(intermediate); + } + + /// Computes linear scaling with source: D = vector_alpha * accumulator + scalar_beta(from host) * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const& accumulator, + FragmentSource const& source, + FragmentCompute const& valpha) const { + // Convert source to internal compute numeric type + NumericArrayConverter source_converter; + NumericArrayConverter accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + FragmentCompute converted_source = source_converter(source); + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + + intermediate = mul_add_source(beta_, converted_source); // X = scalar_beta * C + uniform + + intermediate = mul_add_accumulator(valpha, converted_accumulator, intermediate); // D = vector_alpha * Accum + X + + return destination_converter(intermediate); + } + + /// Computes linear scaling: D = vector_alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const& accumulator, + FragmentCompute const& valpha) const { + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter destination_converter; + + FragmentCompute converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + FragmentCompute intermediate; + multiplies mul_accumulator; + + intermediate = mul_accumulator(valpha, converted_accumulator); // D = vector_alpha * Accum + + return destination_converter(intermediate); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace thread diff --git a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h index 7970b5f7fc..15cc10f483 100644 --- a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h +++ b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h @@ -40,6 +40,7 @@ #include "cutlass/array.h" #include "cutlass/functional.h" #include "cutlass/numeric_conversion.h" +#include "cutlass/platform/platform.h" #include "cutlass/epilogue/thread/activation.h" #include "cutlass/epilogue/thread/scale_type.h" @@ -52,6 +53,18 @@ namespace thread { ///////////////////////////////////////////////////////////////////////////////////////////////// +// If kIsHeavy is a member, use it. Otherwise, assume that it's false. +namespace { // (anonymous) +template +struct kIsHeavy_member_or_false { + static constexpr bool value = false; +}; +template +struct kIsHeavy_member_or_false::type> { + static constexpr bool value = Op::kIsHeavy; +}; +} // namespace (anonymous) + /// This base class is meant to define the concept required of the /// EpilogueWithBroadcast::OutputOp template < @@ -99,7 +112,7 @@ class LinearCombinationBiasElementwise { using ActivationFunctor = ElementwiseOp; static const ScaleType::Kind kScale = ScaleType::Default; - static bool const kIsHeavy = ElementwiseOp::kIsHeavy; + static bool const kIsHeavy = kIsHeavy_member_or_false::value; /// If true, the 'Z' tensor is stored static bool const kStoreZ = true; diff --git a/include/cutlass/epilogue/thread/linear_combination_tensor_broadcast.hpp b/include/cutlass/epilogue/thread/linear_combination_tensor_broadcast.hpp index c89f28895b..a5ede4e628 100644 --- a/include/cutlass/epilogue/thread/linear_combination_tensor_broadcast.hpp +++ b/include/cutlass/epilogue/thread/linear_combination_tensor_broadcast.hpp @@ -119,6 +119,7 @@ class LinearCombinationTensorBroadcast { using ElementOutput = ElementOutput_; using ElementAccumulator = ElementAccumulator_; using ElementCompute = ElementCompute_; + using ElementScalar = ElementCompute; using ElementBias = ElementBias_; using ElementC = ElementSource_; using ElementD = ElementOutput_; diff --git a/include/cutlass/epilogue/thread/scale_type.h b/include/cutlass/epilogue/thread/scale_type.h index f2299277a9..190852ec87 100644 --- a/include/cutlass/epilogue/thread/scale_type.h +++ b/include/cutlass/epilogue/thread/scale_type.h @@ -45,13 +45,17 @@ namespace thread { ///////////////////////////////////////////////////////////////////////////////////////////////// /// Specifies internal data type for computation +/// Note : +/// 1. Scalar means alpha/beta is a single value from host(constant param) or device memory. +/// 2. Vector means alpha/beta is a vector always from device memory. struct ScaleType { enum Kind { - Default, // alpha x C + beta x D - NoBetaScaling, // alpha x C + D - OnlyAlphaScaling, // alpha x C - OnlyAlphaPerChannelScaling, // alpha_vec x C - Nothing // C + Default, // D = scalar_alpha x Acc + scalar_beta x C + NoBetaScaling, // D = scalar_alpha x Acc + C + OnlyAlphaScaling, // D = scalar_alpha x Acc + PerChannelScaling, // D = vector_alpha x Acc + vector_beta x C + OnlyAlphaPerChannelScaling, // D = vector_alpha x Acc + Nothing // D = Acc }; }; diff --git a/include/cutlass/epilogue/threadblock/epilogue.h b/include/cutlass/epilogue/threadblock/epilogue.h index 61d961df3e..42ca5573e8 100644 --- a/include/cutlass/epilogue/threadblock/epilogue.h +++ b/include/cutlass/epilogue/threadblock/epilogue.h @@ -197,6 +197,10 @@ class Epilogue : SourceAspectNotNeeded() {} + // No-op + CUTLASS_DEVICE + void load() { } + /// Invoke the output functor over each vector of output CUTLASS_DEVICE void apply_output_operator( @@ -266,6 +270,13 @@ class Epilogue : source_fragment.clear(); } + // Load addend source fragment from global memory + CUTLASS_DEVICE + void load() { + source_iterator.load(source_fragment); + ++source_iterator; + } + /// Invoke the output functor over each vector of output CUTLASS_DEVICE void apply_output_operator( @@ -273,10 +284,6 @@ class Epilogue : OutputOp const &output_op, typename SharedLoadIterator::Fragment const &aligned_accum_fragment) { - // Load addend source fragment from global memory - source_iterator.load(source_fragment); - ++source_iterator; - apply_output_operator(output_fragment, output_op, aligned_accum_fragment, source_fragment); } }; @@ -439,23 +446,11 @@ class Epilogue : ++accum_fragment_iterator; } - CUTLASS_PRAGMA_UNROLL - for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { - typename AccumulatorFragmentIterator::Fragment accum_fragment; + typename AccumulatorFragmentIterator::Fragment accum_fragment; - accum_fragment_iterator.load(accum_fragment); - ++accum_fragment_iterator; - - warp_tile_iterator.store(accum_fragment); - if (p < Base::kFragmentsPerIteration - 1) { - warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); - } - } - - if (Base::kFragmentsPerIteration > 1) { - warp_tile_iterator.add_pointer_offset(kSmemPointerOffset * - (1 - Base::kFragmentsPerIteration)); - } + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + warp_tile_iterator.store(accum_fragment); } CUTLASS_DEVICE @@ -483,10 +478,14 @@ class Epilogue : // Iterate over accumulator tile // - #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration : 1) - for (int iter = 0; iter < OutputTileIterator::kIterations; iter += Base::kFragmentsPerIteration) + #pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + // + // Load the source + // + source.load(); // // Convert and store fragment // @@ -496,32 +495,23 @@ class Epilogue : acc2smem>::push( iter, accum_fragment_iterator, this->warp_tile_iterator_); + __syncthreads(); + // // Load fragments from shared memory // - __syncthreads(); + typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; + shared_load_iterator_.load(aligned_accum_fragment[0]); - CUTLASS_PRAGMA_UNROLL - for (int p = 0; p < Base::kFragmentsPerIteration; ++p) - { - typename SharedLoadIterator::Fragment aligned_accum_fragment; - shared_load_iterator_.load(aligned_accum_fragment); + if (kPartitionsK > 1) { + plus add_fragments; - if (p < Base::kFragmentsPerIteration - 1) - { + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); - } - else if (kPartitionsK > 1) - { - plus add_fragments; - - CUTLASS_PRAGMA_UNROLL - for ( int i = 1; i < kPartitionsK; ++i) { - typename SharedLoadIterator::Fragment aligned_accum_fragment_addend; - shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); - shared_load_iterator_.load(aligned_accum_fragment_addend); - aligned_accum_fragment = add_fragments(aligned_accum_fragment, aligned_accum_fragment_addend); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]); } shared_load_iterator_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); @@ -532,7 +522,7 @@ class Epilogue : // typename OutputTileIterator::Fragment output_fragment; - source.apply_output_operator(output_fragment, output_op, aligned_accum_fragment); + source.apply_output_operator(output_fragment, output_op, aligned_accum_fragment[0]); // // Store the final result @@ -540,14 +530,8 @@ class Epilogue : destination_iterator.store(output_fragment); ++destination_iterator; - } - - if (Base::kFragmentsPerIteration > 1) { - shared_load_iterator_.add_pointer_offset(kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); - } } } - }; //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/output_tile_thread_map.h b/include/cutlass/epilogue/threadblock/output_tile_thread_map.h index 828b7a7200..764f9cf937 100644 --- a/include/cutlass/epilogue/threadblock/output_tile_thread_map.h +++ b/include/cutlass/epilogue/threadblock/output_tile_thread_map.h @@ -394,7 +394,8 @@ struct OutputTileOptimalThreadMap { CUTLASS_DEVICE static MatrixCoord initial_offset(int thread_idx) { - int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0); +// int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0); + int warp_idx = thread_idx / kWarpSize; int lane_idx = thread_idx % kWarpSize; // Compute warp location @@ -464,7 +465,8 @@ struct OutputTileOptimalThreadMap { CUTLASS_DEVICE static MatrixCoord initial_offset(int thread_idx) { - int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0); +// int warp_idx = __shfl_sync(0xffffffff, thread_idx / kWarpSize, 0); + int warp_idx = thread_idx / kWarpSize; int lane_idx = thread_idx % kWarpSize; // Compute warp location diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h index 4da07d4526..95b087aeed 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h @@ -640,23 +640,25 @@ class PredicatedTileIterator { ++state_[0]; - if (!ScatterD && !PermuteD) { - store_byte_pointer_ += params_.advance_row; - } - if (!ScatterD) { byte_pointer_ += params_.advance_row; } + if (!ScatterD && !PermuteD) { + store_byte_pointer_ += params_.advance_row; + } + thread_start_row_ += ThreadMap::Shape::kRow; if (state_[0] == ThreadMap::Count::kRow) { state_[0] = 0; ++state_[1]; + if (!ScatterD) { byte_pointer_ += params_.advance_group; } + if (!ScatterD && !PermuteD) { store_byte_pointer_ += params_.advance_group; } @@ -668,9 +670,11 @@ class PredicatedTileIterator { state_[1] = 0; ++state_[2]; + if (!ScatterD) { byte_pointer_ += params_.advance_cluster; } + if (!ScatterD && !PermuteD) { store_byte_pointer_ += params_.advance_cluster; } @@ -680,9 +684,11 @@ class PredicatedTileIterator { if (state_[2] == ThreadMap::Count::kCluster) { state_[2] = 0; + if (!ScatterD) { byte_pointer_ += params_.advance_tile; } + if (!ScatterD && !PermuteD) { store_byte_pointer_ += params_.advance_tile; } diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h index 937409afc0..33bf30a51e 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h @@ -291,10 +291,20 @@ struct PredicatedTileIteratorDirect2dConvParams{ // Fastdivmod for output O, P, Q if(threadblock_output_shape.row() != 0 && threadblock_output_shape.column() !=0 ){ + // MSVC emits a "potential divide by 0" warning as error + // if the code just divides without a check and substitution. + + CUTLASS_ASSERT(threadblock_output_shape.row() != 0); + const auto row_denom = threadblock_output_shape.row() != 0 ? + threadblock_output_shape.row() : cutlass::MatrixCoord::Index(1); int tiles_p = - (problem_size.P + (threadblock_output_shape.row() - 1)) / (threadblock_output_shape.row()); + (problem_size.P + (threadblock_output_shape.row() - 1)) / row_denom; + + CUTLASS_ASSERT(threadblock_output_shape.column() != 0); + const auto col_denom = threadblock_output_shape.column() != 0 ? + threadblock_output_shape.column() : cutlass::MatrixCoord::Index(1); int tiles_q = (problem_size.Q + (threadblock_output_shape.column() - 1)) / - (threadblock_output_shape.column()); + col_denom; pq_divmod = FastDivmod(tiles_p * tiles_q); q_divmod = FastDivmod(tiles_q); diff --git a/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h b/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h index a4cabd7179..d2de082503 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h +++ b/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h @@ -416,7 +416,7 @@ struct TileIteratorVoltaTensorOp, float, AccessType *frag_ptr = reinterpret_cast(&frag); - assert(0); // TODO + assert(0); } /// Load diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index c449def394..e1821f1efd 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -152,7 +152,7 @@ CUTLASS_HOST_DEVICE dividend_t round_nearest(dividend_t dividend, divisor_t divi * Greatest common divisor */ template -CUTLASS_HOST_DEVICE value_t gcd(value_t a, value_t b) { +CUTLASS_HOST_DEVICE constexpr value_t gcd(value_t a, value_t b) { for (;;) { if (a == 0) return b; b %= a; @@ -165,7 +165,7 @@ CUTLASS_HOST_DEVICE value_t gcd(value_t a, value_t b) { * Least common multiple */ template -CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b) { +CUTLASS_HOST_DEVICE constexpr value_t lcm(value_t a, value_t b) { value_t temp = gcd(a, b); return temp ? (a / temp * b) : 0; @@ -459,7 +459,6 @@ struct FastDivmodU64 { } quotient = (x >> shift_right); #else - // TODO - use proper 'fast' division here also. No reason why x86-code shouldn't be optimized. quotient = dividend / divisor; #endif diff --git a/include/cutlass/float8.h b/include/cutlass/float8.h index 0eee0fd150..c97abf3648 100644 --- a/include/cutlass/float8.h +++ b/include/cutlass/float8.h @@ -75,7 +75,6 @@ #include #include "cutlass/cutlass.h" - /////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -1080,7 +1079,7 @@ struct numeric_limits : /// Minimum finite value static cutlass::float_e4m3_t lowest() { return cutlass::float_e4m3_t::bitcast(0xfe); } - /// Returns smallest finite value + /// Machine epsilon, that is, the difference between 1.0 and the next representable value static cutlass::float_e4m3_t epsilon() { return cutlass::float_e4m3_t::bitcast(0x20); } }; @@ -1093,7 +1092,7 @@ struct numeric_limits : /// Minimum finite value static cutlass::float_e5m2_t lowest() { return cutlass::float_e5m2_t::bitcast(0xfb); } - /// Returns smallest finite value + /// Machine epsilon, that is, the difference between 1.0 and the next representable value static cutlass::float_e5m2_t epsilon() { return cutlass::float_e5m2_t::bitcast(0x34); } }; @@ -1161,7 +1160,7 @@ struct numeric_limits : /// Minimum finite value static cutlass::float_e4m3_t lowest() { return cutlass::float_e4m3_t::bitcast(0xfe); } - /// Returns smallest finite value + /// Machine epsilon, that is, the difference between 1.0 and the next representable value static cutlass::float_e4m3_t epsilon() { return cutlass::float_e4m3_t::bitcast(0x20); } }; @@ -1174,7 +1173,7 @@ struct numeric_limits : /// Minimum finite value static cutlass::float_e5m2_t lowest() { return cutlass::float_e5m2_t::bitcast(0xfb); } - /// Returns smallest finite value + /// Machine epsilon, that is, the difference between 1.0 and the next representable value static cutlass::float_e5m2_t epsilon() { return cutlass::float_e5m2_t::bitcast(0x34); } }; diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index b8b79e316f..227ce2e159 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -94,6 +94,20 @@ struct multiplies { } }; +template +struct scale { + T const scaling_factor_; + + CUTLASS_HOST_DEVICE + scale(float scaling_factor) : scaling_factor_(scaling_factor) { + } + + T operator()(T const &rhs) const { + T result = rhs * scaling_factor_; + return result; + } +}; + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 /// Partial specializations needed when __CUDA_NO_HALF2_OPERATORS__ is set template<> @@ -147,36 +161,6 @@ struct multiplies<__half> { #endif // defined(__CUDA_ARCH__) -// Maximum with nan propogation -// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN -template -struct maximum_with_nan_propogation { - CUTLASS_HOST_DEVICE - T operator()(T const &lhs, T const &rhs) const { -#if defined(__CUDA_ARCH__) - return lhs > rhs or isnan(lhs) ? lhs : rhs; -#else - return lhs > rhs or std::isnan(lhs) ? lhs : rhs; -#endif - } -}; - -template <> -struct maximum_with_nan_propogation { - CUTLASS_HOST_DEVICE - float operator()(float const lhs, float const rhs) const { - float res; -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs)); -#elif defined(__CUDA_ARCH__) - res = lhs > rhs or isnan(lhs) ? lhs : rhs; -#else - res = lhs > rhs or std::isnan(lhs) ? lhs : rhs; -#endif - return res; - } -}; - /// Squares with optional conversion template struct square { @@ -280,40 +264,106 @@ struct less { } }; -template +template struct maximum { - CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const { return (lhs < rhs ? rhs : lhs); } }; +// Maximum with nan propogation +// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN +template +struct maximum { + CUTLASS_HOST_DEVICE + T operator()(T const &lhs, T const &rhs) const { +#if defined(__CUDA_ARCH__) + return lhs > rhs or isnan(lhs) ? lhs : rhs; +#else + return lhs > rhs or std::isnan(lhs) ? lhs : rhs; +#endif + } +}; + template <> -struct maximum { +struct maximum { CUTLASS_HOST_DEVICE float operator()(float const &lhs, float const &rhs) const { return fmaxf(lhs, rhs); } }; +template <> +struct maximum { + CUTLASS_HOST_DEVICE + float operator()(float const lhs, float const rhs) const { + float res; +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs)); +#elif defined(__CUDA_ARCH__) + res = lhs > rhs or isnan(lhs) ? lhs : rhs; +#else + res = lhs > rhs or std::isnan(lhs) ? lhs : rhs; +#endif + return res; + } +}; + template -struct minimum { +using maximum_with_nan_propogation = maximum; +template +struct minimum{ CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const { return (rhs < lhs ? rhs : lhs); } }; +template +struct minimum { + CUTLASS_HOST_DEVICE + T operator()(T const &lhs, T const &rhs) const { +#if defined(__CUDA_ARCH__) + return lhs < rhs or isnan(lhs) ? lhs : rhs; +#else + return lhs < rhs or std::isnan(lhs) ? lhs : rhs; +#endif + } +}; + template <> -struct minimum { +struct minimum { CUTLASS_HOST_DEVICE float operator()(float const &lhs, float const &rhs) const { return fminf(lhs, rhs); } }; +template +struct maximum_absolute_value { + CUTLASS_HOST_DEVICE + float operator()(T const &lhs, T const &rhs) const { + absolute_value_op abs_op; + maximum max_op; + + return max_op(abs_op(lhs), abs_op(rhs)); + } +}; + +// assumes the left operand is already an absolute value +template +struct maximum_absolute_value_reduction { + CUTLASS_HOST_DEVICE + float operator()(T const &lhs, T const &rhs) const { + absolute_value_op abs_op; + maximum max_op; + + return max_op(lhs, abs_op(rhs)); + } +}; + /// Fused multiply-add template struct multiply_add { @@ -360,6 +410,14 @@ struct conjugate { } }; +template +struct first { + CUTLASS_HOST_DEVICE + T operator()(T const & first, T const &...) const { + return first; + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// template @@ -423,22 +481,22 @@ struct bit_xor { ////////////////////////////////////////////////////////////////////////////////////////////////// +/// Atomic reductions -/// Reduces value into the data pointed to by ptr template -struct red +struct atomic_add { CUTLASS_DEVICE void operator()(T *ptr, const T &data) { +#if defined(__CUDA_ARCH__) atomicAdd(ptr, data); +#endif } }; - -/// Reduces value into the data pointed to by ptr (double specialization) template<> -struct red +struct atomic_add { CUTLASS_DEVICE void operator()(double *ptr, const double &data) @@ -447,11 +505,8 @@ struct red CUTLASS_UNUSED(ptr); CUTLASS_UNUSED(data); #elif (__CUDA_ARCH__ >= 600) - atomicAdd(ptr, data); - #else - // Use CAS loop unsigned long long int* ptr_int = reinterpret_cast(ptr); unsigned long long int old_int = *ptr_int; @@ -462,15 +517,12 @@ struct red assumed_int = old_int; old_int = atomicCAS(ptr_int, assumed_int, __double_as_longlong(update)); } while (assumed_int != old_int); - #endif // (__CUDA_ARCH__ >= 600) } }; - -/// Reduces value into the data pointed to by ptr (half2 specialization) template<> -struct red +struct atomic_add { CUTLASS_DEVICE void operator()(half2 *ptr, const half2 &data) @@ -479,15 +531,48 @@ struct red CUTLASS_UNUSED(ptr); CUTLASS_UNUSED(data); #else - // Vector-2 atomic reduction requires .target sm_60 or higher uint32_t word = reinterpret_cast(data); asm volatile ("red.gpu.global.add.noftz.f16x2 [%0], %1;\n" : : "l"(ptr), "r"(word)); - #endif // (__CUDA_ARCH__ >= 600) } }; +template +using red [[deprecated("use atomic_add instead")]] = atomic_add; + +template +struct atomic_maximum { + CUTLASS_DEVICE + T operator()(T *ptr, T value) const { +#if defined(__CUDA_ARCH__) + return atomicMax(ptr, value); +#else + CUTLASS_UNUSED(ptr); + CUTLASS_UNUSED(value); + CUTLASS_NOT_IMPLEMENTED(); + return 0; +#endif + } +}; + +template <> +struct atomic_maximum { + CUTLASS_DEVICE + float operator()(float *ptr, float value) const { +#if defined(__CUDA_ARCH__) + return !signbit(value) ? + __int_as_float(atomicMax((int*)ptr, __float_as_int(value))) : + __uint_as_float(atomicMin((unsigned int*)ptr, __float_as_uint(value))); +#else + CUTLASS_UNUSED(ptr); + CUTLASS_UNUSED(value); + CUTLASS_NOT_IMPLEMENTED(); + return 0; +#endif + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// // diff --git a/include/cutlass/gemm/collective/builders/sm90_common.inl b/include/cutlass/gemm/collective/builders/sm90_common.inl new file mode 100644 index 0000000000..b9c76a6432 --- /dev/null +++ b/include/cutlass/gemm/collective/builders/sm90_common.inl @@ -0,0 +1,346 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/layout.hpp" + +#include "cute/atom/mma_traits_sm90_gmma.hpp" +#include "cute/atom/copy_traits_sm90_tma.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// +// Some named constants +// +constexpr int tma_alignment_bytes = 16; +constexpr int cp_async_min_alignment_bytes = 4; +constexpr int sm90_smem_capacity_bytes = 232448; + +// Maps 2.x A matrix layout tag to respective GMMA major mode enum +template +constexpr cute::GMMA::Major +gmma_ss_tag_to_major_A() { + // MN major mode is only valid for non-TF32, non-int and non-fp8 MMAs + if constexpr (cutlass::gemm::detail::is_mn_major_A() && + not cute::is_same_v && + sizeof(ElementA) != 1) { + return cute::GMMA::Major::MN; + } + else { + return cute::GMMA::Major::K; + } +} + +// Maps 2.x B matrix layout tag to respective GMMA major mode enum +template +constexpr cute::GMMA::Major +gmma_ss_tag_to_major_B() { + // MN major mode is only valid for non-TF32, non-int and non-fp8 MMAs + if constexpr (cutlass::gemm::detail::is_mn_major_B() && + not cute::is_same_v && + sizeof(ElementB) != 1) { + return cute::GMMA::Major::MN; + } + else { + return cute::GMMA::Major::K; + } +} + +template +constexpr cute::GMMA::Major +gmma_rs_tag_to_major_A() { + // MN major mode is only valid for non-TF32 and non-int MMAs + if constexpr (cutlass::gemm::detail::is_mn_major_A()) { + return cute::GMMA::Major::MN; + } + else { + return cute::GMMA::Major::K; + } +} + +template +constexpr cute::GMMA::Major +gmma_rs_tag_to_major_B() { + // MN major mode is only valid for non-TF32 and non-int MMAs + if constexpr (cutlass::gemm::detail::is_mn_major_B()) { + return cute::GMMA::Major::MN; + } + else { + return cute::GMMA::Major::K; + } +} +// Maps a rank-1 cute::Shape<> representing the cluster shape on to the TMA atom that should be used with it +template +constexpr auto +sm90_cluster_shape_to_tma_atom(UnimodalClusterShape) { + static_assert(cute::rank(UnimodalClusterShape{}) == 1, + "Use this function to figure out TMA for each mode individually."); + + if constexpr (cute::size(UnimodalClusterShape{}) == 1) { + return cute::SM90_TMA_LOAD{}; + } + else { + return cute::SM90_TMA_LOAD_MULTICAST{}; + } +} + +// Generates the most efficient possible TiledCopy with cp.async copy atom given a set of parameters. +template +constexpr auto +make_cp_async_gmem_tiled_copy() { + using AlignmentType = cute::uint_byte_t(sizeof(Element)) * Alignment>; + constexpr int TileSizeMN = cute::size(TileMN{}); + constexpr int TileSizeK = cute::size(TileK{}); + + // Maximize the number of threads along the gmem major mode to promote coalesced reads + // While making sure our thread layout tiles the threadblock tile evenly + + if constexpr (cutlass::gemm::detail::is_k_major()) { + // K major thread layout for K major gmem + constexpr int threads_major = TileSizeK / Alignment; + constexpr int threads_minor = ThreadCount / threads_major; + static_assert(threads_major > 0); + static_assert(ThreadCount % threads_major == 0); + static_assert(threads_minor == 0 || (TileSizeMN % threads_minor == 0)); + return make_tiled_copy( + Copy_Atom, Element>{}, + Layout,Int>, + Stride, _1>>{}, + Layout>>{}); + } + else if constexpr (cutlass::gemm::detail::is_mn_major()) { + // MN major thread layout for MN major gmem + constexpr int threads_major = TileSizeMN / Alignment; + constexpr int threads_minor = ThreadCount / threads_major; + static_assert(threads_major > 0); + static_assert(ThreadCount % threads_major == 0); + static_assert(threads_minor == 0 || (TileSizeK % threads_minor == 0)); + return make_tiled_copy( + Copy_Atom, Element>{}, + Layout,Int>, + Stride< _1,Int>>{}, + Layout,_1>>{}); + } + else { + static_assert(cute::is_void_v, "Unsupported gmem layout for automatic gmem tiled copy builder."); + } +} + +// Helper for SS GMMA smem selection that considers a tensor TileShape: +// (BLK_MN, BLK_K) +// or hierarchically +// ((BLK_MN0,BLK_MN1,...),(BLK_K0,BLK_K1,...)) +// and returns the optimal GMMA::Layout that fits BLK_MN0 and BLK_K0 +template +constexpr auto +rs_smem_selector() { + auto BLK_MN0 = size<0>(BLK_MN{}); + auto BLK_K0 = size<0>(BLK_K{}); + + static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8."); + static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8."); + if constexpr (major == GMMA::Major::MN) { + if constexpr (sizeof(ElementType) == 4){ + if constexpr (is_ws_transposed_B) { + // only optimized transpositionB(SW32 and SW128 for tf32) can be used, but prefer SW32 due to free bank conflict + if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { + return GMMA::Layout_MN_SW32_Atom{}; + } + else { + static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_SW32_Atom{})"); + } + } + else { + // Fall into SW32 due to free bank conflict + if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { + return GMMA::Layout_MN_SW32_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0) { + return GMMA::Layout_MN_INTER_Atom{}; + } + else { + static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom{})"); + } + } + } + // Used for int8, fp8, fp16 and bf16 I/O kernels + else if constexpr (sizeof(ElementType) == 1 || sizeof(ElementType) == 2) { + if constexpr (sizeof(ElementType) == 1 && is_ws_transposed_B) { + // Only optimized transpositionB (SW32 for int8 and fp8) can be used + if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0) { + return GMMA::Layout_MN_SW128_Atom{}; + } + else { + static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_128_Atom{})"); + } + } + else { + if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0) { + return GMMA::Layout_MN_SW128_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_Atom{}) == 0) { + return GMMA::Layout_MN_SW64_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { + return GMMA::Layout_MN_SW32_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0) { + return GMMA::Layout_MN_INTER_Atom{}; + } + else { + static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom{})"); + } + } + } + else { + static_assert(cutlass::detail::dependent_false, "Smem selector does not support this element type"); + } + } + else if constexpr (major == GMMA::Major::K) { + if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom{}) == 0) { + return GMMA::Layout_K_SW128_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_Atom{}) == 0) { + return GMMA::Layout_K_SW64_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_Atom{}) == 0) { + return GMMA::Layout_K_SW32_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0) { + return GMMA::Layout_K_INTER_Atom{}; + } + else { + static_assert(BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0, + "BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom{})"); + } + } +} + +// Helper for SS GMMA smem selection that considers a tensor TileShape: +// (BLK_MN, BLK_K) +// or hierarchically +// ((BLK_MN0,BLK_MN1,...),(BLK_K0,BLK_K1,...)) +// and returns the largest GMMA::Layout that fits BLK_MN0 and BLK_K0 +template +CUTE_HOST_DEVICE constexpr +auto +ss_smem_selector() +{ + auto BLK_MN0 = size<0>(BLK_MN{}); + auto BLK_K0 = size<0>(BLK_K{}); + + static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8."); + static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8."); + + + if constexpr (major == GMMA::Major::MN) { + if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0) { + return GMMA::Layout_MN_SW128_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_Atom{}) == 0) { + return GMMA::Layout_MN_SW64_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { + return GMMA::Layout_MN_SW32_Atom{}; + } + else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0) { + return GMMA::Layout_MN_INTER_Atom{}; + } + else { + static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0, + "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom{})"); + } + } + else if constexpr (major == GMMA::Major::K) { + if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom{}) == 0) { + return GMMA::Layout_K_SW128_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_Atom{}) == 0) { + return GMMA::Layout_K_SW64_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_Atom{}) == 0) { + return GMMA::Layout_K_SW32_Atom{}; + } + else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0) { + return GMMA::Layout_K_INTER_Atom{}; + } + else { + static_assert(BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0, + "BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom{})"); + } + } +} + +template +constexpr bool +is_input_size_two_bytes() { + return (sizeof(ElementA) == 2 && sizeof(ElementB) == 2); +} + +template +constexpr bool +is_input_fp8() { + return ((cute::is_same_v || cute::is_same_v) && + (cute::is_same_v || cute::is_same_v)); +} + +template +constexpr bool +is_use_rmem_A() { + constexpr bool IsInputSizeTwoBytes = is_input_size_two_bytes(); + constexpr bool IsLayoutAkBk = cutlass::gemm::detail::is_k_major_A() && + cutlass::gemm::detail::is_k_major_B(); + constexpr bool IsUseRmemA = !IsInputSizeTwoBytes && !IsLayoutAkBk; + return IsUseRmemA; +} + +template +constexpr bool +is_aligned() { + return ((sizeof(ElementA) * AlignmentA) % RequiredAlignment == 0) && + ((sizeof(ElementB) * AlignmentB) % RequiredAlignment == 0); +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index 9f922e3c74..7d8f591358 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -33,8 +33,8 @@ #include "cutlass/arch/mma.h" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/dispatch_policy.hpp" -#include "cute/atom/mma_traits_sm90_gmma.hpp" -#include "cute/atom/copy_traits_sm90_tma.hpp" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" // SM90 Collective Builders should be used only starting CUDA 12.0 #if (__CUDACC_VER_MAJOR__ >= 12) @@ -49,124 +49,6 @@ namespace cutlass::gemm::collective { namespace detail { -// -// Some named constants -// -constexpr int tma_alignment_bytes = 16; -constexpr int cp_async_min_alignment_bytes = 4; -constexpr int sm90_smem_capacity_bytes = 232448; - -// Maps 2.x A matrix layout tag to respective GMMA major mode enum -template -constexpr cute::GMMA::Major -gmma_ss_tag_to_major_A() { - // MN major mode is only valid for non-TF32, non-int - if constexpr (cutlass::gemm::detail::is_mn_major_A() && - not cute::is_same_v && - sizeof(ElementA) != 1) { - return cute::GMMA::Major::MN; - } - else { - return cute::GMMA::Major::K; - } -} - -// Maps 2.x B matrix layout tag to respective GMMA major mode enum -template -constexpr cute::GMMA::Major -gmma_ss_tag_to_major_B() { - // MN major mode is only valid for non-TF32, non-int - if constexpr (cutlass::gemm::detail::is_mn_major_B() && - not cute::is_same_v && - sizeof(ElementB) != 1) { - return cute::GMMA::Major::MN; - } - else { - return cute::GMMA::Major::K; - } -} - -template -constexpr cute::GMMA::Major -gmma_rs_tag_to_major_A() { - // MN major mode is only valid for non-TF32 and non-int MMAs - if constexpr (cutlass::gemm::detail::is_mn_major_A()) { - return cute::GMMA::Major::MN; - } - else { - return cute::GMMA::Major::K; - } -} - -template -constexpr cute::GMMA::Major -gmma_rs_tag_to_major_B() { - // MN major mode is only valid for non-TF32 and non-int MMAs - if constexpr (cutlass::gemm::detail::is_mn_major_B()) { - return cute::GMMA::Major::MN; - } - else { - return cute::GMMA::Major::K; - } -} -// Maps a rank-1 cute::Shape<> representing the cluster shape on to the TMA atom that should be used with it -template -constexpr auto -sm90_cluster_shape_to_tma_atom(UnimodalClusterShape unimodal_cluster_shape) { - static_assert(cute::rank(unimodal_cluster_shape) == 1, - "Use this function to figure out TMA for each mode individually."); - - if constexpr (cute::size(unimodal_cluster_shape) == 1) { - return cute::SM90_TMA_LOAD{}; - } - else { - return cute::SM90_TMA_LOAD_MULTICAST{}; - } -} - -// Generates the most efficient possible TiledCopy with cp.async copy atom given a set of parameters. -template -constexpr auto -make_cp_async_gmem_tiled_copy() { - using AlignmentType = cute::uint_byte_t(sizeof(Element)) * Alignment>; - constexpr int TileSizeMN = cute::size(TileMN{}); - constexpr int TileSizeK = cute::size(TileK{}); - - // Maximize the number of threads along the gmem major mode to promote coalesced reads - // While making sure our thread layout tiles the threadblock tile evenly - - if constexpr (cutlass::gemm::detail::is_k_major()) { - // K major thread layout for K major gmem - constexpr int threads_major = TileSizeK / Alignment; - constexpr int threads_minor = ThreadCount / threads_major; - static_assert(threads_major > 0); - static_assert(ThreadCount % threads_major == 0); - static_assert(threads_minor == 0 || (TileSizeMN % threads_minor == 0)); - return make_tiled_copy( - Copy_Atom, Element>{}, - Layout,Int>, - Stride, _1>>{}, - Layout>>{}); - } - else if constexpr (cutlass::gemm::detail::is_mn_major()) { - // MN major thread layout for MN major gmem - constexpr int threads_major = TileSizeMN / Alignment; - constexpr int threads_minor = ThreadCount / threads_major; - static_assert(threads_major > 0); - static_assert(ThreadCount % threads_major == 0); - static_assert(threads_minor == 0 || (TileSizeK % threads_minor == 0)); - return make_tiled_copy( - Copy_Atom, Element>{}, - Layout,Int>, - Stride< _1,Int>>{}, - Layout,_1>>{}); - } - else { - static_assert(cute::is_void_v, "Unsupported gmem layout for automatic gmem tiled copy builder."); - } -} - - // Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. template constexpr int @@ -197,171 +79,6 @@ compute_stage_count_or_override(StageCountAutoCarveout stage_cou return (CapacityBytes - carveout_bytes) / stage_bytes; } -// Helper for SS GMMA smem selection that considers a tensor TileShape: -// (BLK_MN, BLK_K) -// or hierarchically -// ((BLK_MN0,BLK_MN1,...),(BLK_K0,BLK_K1,...)) -// and returns the optimal GMMA::Layout that fits BLK_MN0 and BLK_K0 -template -constexpr auto -rs_smem_selector() { - auto BLK_MN0 = size<0>(BLK_MN{}); - auto BLK_K0 = size<0>(BLK_K{}); - - static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8."); - static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8."); - if constexpr (major == GMMA::Major::MN) { - if constexpr (sizeof(ElementType) == 4){ - if constexpr (is_ws_transposed_B) { - // only optimized transpositionB(SW32 and SW128 for tf32) can be used, but prefer SW32 due to free bank conflict - if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { - return GMMA::Layout_MN_SW32_Atom{}; - } - else { - static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0, - "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_SW32_Atom{})"); - } - } - else { - // Fall into SW32 due to free bank conflict - if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { - return GMMA::Layout_MN_SW32_Atom{}; - } - else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0) { - return GMMA::Layout_MN_INTER_Atom{}; - } - else { - static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0, - "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom{})"); - } - } - } - // Used for int8, fp16 and bf16 I/O kernels - else if constexpr (sizeof(ElementType) == 1 || sizeof(ElementType) == 2) { - if constexpr (sizeof(ElementType) == 1 && is_ws_transposed_B) { - // Only optimized transpositionB (SW32 for int8) can be used - if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0) { - return GMMA::Layout_MN_SW128_Atom{}; - } - else { - static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0, - "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_128_Atom{})"); - } - } - else { - if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0) { - return GMMA::Layout_MN_SW128_Atom{}; - } - else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_Atom{}) == 0) { - return GMMA::Layout_MN_SW64_Atom{}; - } - else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { - return GMMA::Layout_MN_SW32_Atom{}; - } - else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0) { - return GMMA::Layout_MN_INTER_Atom{}; - } - else { - static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0, - "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom{})"); - } - } - } - else { - static_assert(cutlass::detail::dependent_false, "Smem selector does not support this element type"); - } - } - else if constexpr (major == GMMA::Major::K) { - if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom{}) == 0) { - return GMMA::Layout_K_SW128_Atom{}; - } - else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_Atom{}) == 0) { - return GMMA::Layout_K_SW64_Atom{}; - } - else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_Atom{}) == 0) { - return GMMA::Layout_K_SW32_Atom{}; - } - else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0) { - return GMMA::Layout_K_INTER_Atom{}; - } - else { - static_assert(BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0, - "BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom{})"); - } - } -} - -// Helper for SS GMMA smem selection that considers a tensor TileShape: -// (BLK_MN, BLK_K) -// or hierarchically -// ((BLK_MN0,BLK_MN1,...),(BLK_K0,BLK_K1,...)) -// and returns the largest GMMA::Layout that fits BLK_MN0 and BLK_K0 -template -CUTE_HOST_DEVICE constexpr -auto -ss_smem_selector() -{ - auto BLK_MN0 = size<0>(BLK_MN{}); - auto BLK_K0 = size<0>(BLK_K{}); - - static_assert(BLK_MN0 % 8 == 0, "BLK_MN0 must be a multiple of 8."); - static_assert(BLK_K0 % 8 == 0, "BLK_K0 must be a multiple of 8."); - - - if constexpr (major == GMMA::Major::MN) { - if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW128_Atom{}) == 0) { - return GMMA::Layout_MN_SW128_Atom{}; - } - else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW64_Atom{}) == 0) { - return GMMA::Layout_MN_SW64_Atom{}; - } - else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_SW32_Atom{}) == 0) { - return GMMA::Layout_MN_SW32_Atom{}; - } - else if constexpr (BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0) { - return GMMA::Layout_MN_INTER_Atom{}; - } - else { - static_assert(BLK_MN0 % size<0>(GMMA::Layout_MN_INTER_Atom{}) == 0, - "BLK_MN0 must be a multiple of size<0>(GMMA::Layout_MN_INTER_Atom{})"); - } - } - else if constexpr (major == GMMA::Major::K) { - if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW128_Atom{}) == 0) { - return GMMA::Layout_K_SW128_Atom{}; - } - else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW64_Atom{}) == 0) { - return GMMA::Layout_K_SW64_Atom{}; - } - else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_SW32_Atom{}) == 0) { - return GMMA::Layout_K_SW32_Atom{}; - } - else if constexpr (BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0) { - return GMMA::Layout_K_INTER_Atom{}; - } - else { - static_assert(BLK_K0 % size<1>(GMMA::Layout_K_INTER_Atom{}) == 0, - "BLK_K0 must be a multiple of size<1>(GMMA::Layout_K_INTER_Atom{})"); - } - } -} - -template -constexpr bool -is_input_size_two_bytes() { - return (sizeof(ElementA) == 2 && sizeof(ElementB) == 2); -} - -template -constexpr bool -is_use_rmem_A() { - constexpr bool IsInputSizeTwoBytes = is_input_size_two_bytes(); - constexpr bool IsLayoutAkBk = cutlass::gemm::detail::is_k_major_A() && - cutlass::gemm::detail::is_k_major_B(); - constexpr bool IsUseRmemA = !IsInputSizeTwoBytes && !IsLayoutAkBk; - return IsUseRmemA; -} - template constexpr bool is_swapAB(){ @@ -372,13 +89,6 @@ is_swapAB(){ return SwapAB; } -template -constexpr bool -is_aligned() { - return ((sizeof(ElementA) * AlignmentA) % RequiredAlignment == 0) && - ((sizeof(ElementB) * AlignmentB) % RequiredAlignment == 0); -} - template constexpr bool is_warpspecialized_transpose_B(){ @@ -438,6 +148,8 @@ struct CollectiveBuilder< static_assert(detail::is_aligned(), "Should meet TMA alignment requirement\n"); +static constexpr bool IsFP8Input = detail::is_input_fp8(); + // For fp32 types, map to tf32 MMA value type using MmaElementA = cute::conditional_t, tfloat32_t, ElementA>; using MmaElementB = cute::conditional_t, tfloat32_t, ElementB>; @@ -461,8 +173,10 @@ struct CollectiveBuilder< static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized< - PipelineStages, ClusterShape_MNK, KernelScheduleType>; + /* For FP8 use a separate mainloop compared to other datatypes */ + using DispatchPolicy = cute::conditional_t, + MainloopSm90TmaGmmaWarpSpecialized>; using SmemCopyAtomA = void; using SmemCopyAtomB = void; @@ -583,6 +297,96 @@ struct CollectiveBuilder< ///////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA_TMA_WS_FP8_FAST_ACCUM_SS +template < + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + cute::is_same_v || + cute::is_same_v || + cute::is_same_v> +> { + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Not meet TMA alignment requirement yet\n"); + static_assert(detail::is_input_fp8(), + "Only FP8 datatypes are compatible with these kernel schedules\n"); + // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder + static_assert(!detail::is_use_rmem_A(), + "Not supported for fp8 non-TN warp specialized kernels yet\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = cute::conditional_t, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized< + PipelineStages, ClusterShape_MNK, KernelScheduleType>; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// // GMMA_TMA_SS @@ -794,16 +598,14 @@ struct CollectiveBuilder< static_assert(cutlass::detail::dependent_false == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); #endif -static constexpr bool IsTmaWarpSpecialized = detail::is_aligned< +static constexpr bool IsTmaCompatible = detail::is_aligned< ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(); #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 1))) - // Cooperative schedule performs best for CUDA Toolkits with version >= 12.1 - - // For TileShape_M == 64, choosing KernelTmaWarpSpecialized as the KernelSchedule - // Since KernelTmaWarpSpecializedCooperative requires TileShape_M to be at least 128 + // Persistent schedules perform best for CUDA Toolkits with version >= 12.1 + // KernelTmaWarpSpecializedCooperative requires TileShape_M to be at least 128 using KernelWarpSpecializedSchedule = cute::conditional_t(TileShape_MNK{}) == Int<64>{}, - KernelTmaWarpSpecialized, KernelTmaWarpSpecializedCooperative>; + KernelTmaWarpSpecializedPingpong, KernelTmaWarpSpecializedCooperative>; #else using KernelWarpSpecializedSchedule = KernelTmaWarpSpecialized; #endif @@ -821,7 +623,7 @@ static constexpr bool IsTmaWarpSpecialized = detail::is_aligned< TileShape_MNK, ClusterShape_MNK, StageCountType, - cute::conditional_t + cute::conditional_t >::CollectiveOp; }; diff --git a/include/cutlass/gemm/collective/collective_builder.hpp b/include/cutlass/gemm/collective/collective_builder.hpp index 3c0aa15d45..326e4e121a 100644 --- a/include/cutlass/gemm/collective/collective_builder.hpp +++ b/include/cutlass/gemm/collective/collective_builder.hpp @@ -31,7 +31,7 @@ #pragma once ///////////////////////////////////////////////////////////////////////////////////////////////// -#include "collective_mma.hpp" +#include "cutlass/gemm/collective/collective_mma.hpp" namespace cutlass::gemm::collective { @@ -78,5 +78,5 @@ struct CollectiveBuilder { ///////////////////////////////////////////////////////////////////////////////////////////////// -#include "builders/sm90_gmma_builder.inl" +#include "cutlass/gemm/collective/builders/sm90_gmma_builder.inl" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index 2a0ba6da10..de6c77b4c5 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -65,10 +65,11 @@ struct CollectiveMma { ///////////////////////////////////////////////////////////////////////////////////////////////// -#include "sm70_mma_twostage.hpp" -#include "sm80_mma_multistage.hpp" -#include "sm90_mma_multistage_gmma_ss.hpp" -#include "sm90_mma_tma_gmma_ss.hpp" -#include "sm90_mma_tma_gmma_rs_warpspecialized.hpp" -#include "sm90_mma_tma_gmma_ss_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm70_mma_twostage.hpp" +#include "cutlass/gemm/collective/sm80_mma_multistage.hpp" +#include "cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp" +#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp" +#include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/fp8_accumulation.hpp b/include/cutlass/gemm/collective/fp8_accumulation.hpp new file mode 100644 index 0000000000..ea0e8053ed --- /dev/null +++ b/include/cutlass/gemm/collective/fp8_accumulation.hpp @@ -0,0 +1,121 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cute/algorithm/clear.hpp" +#include "cute/tensor.hpp" + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////FP8 Accumulation/////////////////////////// +////////////////////////////////////////////////////////////////////////////// +/// It would promote (add) the results from the tensor core accumulators to the +/// main accumulators when the number of MMAs reaches the max number of MMA +/// interval specified by user, after that the tensor core accumulators are +/// zeroed. +////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +template < + class EngineAccum, + class LayoutAccum> +struct GmmaFP8Accumulation { + using TensorAccum = cute::Tensor; + + static_assert(is_static::value, "Accumulator Layout should be static"); + static_assert(is_rmem::value , "Accumulator tensor must be rmem resident."); + +private: + TensorAccum& accum_; + TensorAccum accum_temp_; + + uint32_t accum_promotion_interval_; // defines the max num of executed MMAs after which accum should be promoted. + uint32_t mma_count_per_mainloop_iteration_; // num of MMAs per k_tile of mainloop + uint32_t mma_count_; // current executed MMAs + uint32_t reset_accum_flag_; // accum needs to be zeroed or not. + + CUTLASS_DEVICE + void promote_core() { + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accum_); ++i) { + accum_(i) += accum_temp_(i); + } + } + +public: + CUTLASS_DEVICE + GmmaFP8Accumulation( + TensorAccum &accum, + uint32_t accum_promotion_interval, + uint32_t mma_count_per_mainloop_iteration) + : accum_(accum), + accum_promotion_interval_(accum_promotion_interval), + mma_count_per_mainloop_iteration_(mma_count_per_mainloop_iteration), + mma_count_(0), + reset_accum_flag_(0) + { + accum_temp_ = cute::make_fragment_like(accum); + } + + CUTLASS_DEVICE + TensorAccum& operator()() { + return accum_temp_; + } + + /// prepare the MMA accumulators when initialization or zeroing is required. + CUTLASS_DEVICE + bool prepare_if_needed() { + return reset_accum_flag_; + } + + /// promote (add) the results from the MMA accumulators to main accumulator if needed. + CUTLASS_DEVICE + void promote_if_needed() { + mma_count_ += mma_count_per_mainloop_iteration_; + reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0); + if (reset_accum_flag_) { + promote_core(); + mma_count_ = 0; + } + } + + /// promote (add) the residue results from the MMA accumulators to main accumulator if needed. + CUTLASS_DEVICE + void promote_residue_if_needed() { + if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) { + promote_core(); + } + } +}; + +} // namespace cutlass::gemm::collective diff --git a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp index 57c079951f..b842eace70 100644 --- a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp @@ -329,7 +329,8 @@ struct CollectiveMma< using ElementB = ElementB_; using StrideB = StrideB_; using TiledMma = TiledMma_; - using ElementAccumulator = typename TiledMma::ValTypeC; using GmemTiledCopyA = GmemTiledCopyA_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; using GmemTiledCopyB = GmemTiledCopyB_; using SmemLayoutAtomA = SmemLayoutAtomA_; using SmemLayoutAtomB = SmemLayoutAtomB_; @@ -387,6 +388,14 @@ struct CollectiveMma< return args; } + template + CUTLASS_HOST_DEVICE static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + /// Perform a collective-scoped matrix multiply-accumulate template < class FrgTensorD, @@ -399,8 +408,8 @@ struct CollectiveMma< CUTLASS_DEVICE void operator() ( FrgTensorD &accum, - TensorA gA, - TensorB gB, + TensorA gA_in, + TensorB gB_in, FrgTensorC const &src_accum, KTileIterator k_tile_iter, int k_tile_count, ResidueMNK residue_mnk, @@ -432,8 +441,8 @@ struct CollectiveMma< // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) // This aligns the tensor with BLK_K for all but the 0th k_tile - gA.data() = &gA(0, get<2>(residue_mnk), 0); - gB.data() = &gB(0, get<2>(residue_mnk), 0); + Tensor gA = domain_offset(make_coord(0, get<2>(residue_mnk), 0), gA_in); + Tensor gB = domain_offset(make_coord(0, get<2>(residue_mnk), 0), gB_in); // Partition the copying of A and B tiles across the threads GmemTiledCopyA gmem_tiled_copy_a; diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp index faf2857acd..94f4656f5a 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp @@ -44,6 +44,8 @@ #include "cute/numeric/arithmetic_tuple.hpp" #include "cutlass/pipeline/pipeline.hpp" #include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" +#include "cutlass/trace.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::collective { @@ -106,7 +108,7 @@ struct CollectiveMma< using SmemCopyAtomA = SmemCopyAtomA_; using SmemCopyAtomB = SmemCopyAtomB_; - // Swap and transpose A/B for A k-major layout and B mn-major layout since WGMMA is k-major only (e.g. tf32, Fp32, Int8 WGMMA) + // Swap and transpose A/B for A k-major layout and B mn-major layout since WGMMA is k-major only (e.g. tf32, Fp32, Int8, Fp8 WGMMA) static constexpr bool IsLayoutAkBmn = cute::is_same_v, layout::RowMajor> && cute::is_same_v, layout::RowMajor>; @@ -117,7 +119,17 @@ struct CollectiveMma< using InternalSmemLayoutAtomB = cute::conditional_t; using InternalSmemCopyAtomA = cute::conditional_t; using InternalSmemCopyAtomB = cute::conditional_t; - + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using ConvertedElementA = cute::conditional_t>>; + using ConvertedElementB = cute::conditional_t>>; + using InternalElementA = cute::conditional_t; + using InternalElementB = cute::conditional_t; + using InternalStrideA = cute::conditional_t; + using InternalStrideB = cute::conditional_t; + using TransformA = TransformA_; using TransformB = TransformB_; using ArchTag = typename DispatchPolicy::ArchTag; @@ -137,24 +149,26 @@ struct CollectiveMma< static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - // Tile along K mode first before tiling over MN. PIPE mode last as usual. - // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + // Tile along modes in a way that maximizes the TMA box size. using SmemLayoutA = decltype(tile_to_shape( InternalSmemLayoutAtomA{}, make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), - Step<_2,_1,_3>{})); + conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); using SmemLayoutB = decltype(tile_to_shape( InternalSmemLayoutAtomB{}, make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), - Step<_2,_1,_3>{})); + conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); // If A mn-layout and B mn-layout, transposing B matrix since WGMMA is k-major only (e.g. tf32, fp32, fp8, int8). static constexpr bool IsLayoutAmnBmn = cute::is_same_v, layout::ColumnMajor> && cute::is_same_v, layout::RowMajor>; static constexpr bool TransposeB = !IsInputSizeTwoBytes && IsLayoutAmnBmn; + using TransposeOperandB = decltype(cutlass::transform::collective::detail::make_transpose_operand_b( + 0, 0, TiledMma{}, SmemLayoutB{}, InternalSmemLayoutAtomB{}, + InternalElementB{}, cute::bool_constant{})); - static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); static_assert(not cute::is_base_of::value && cute::is_base_of::value, "MMA atom must source A from rmem and B operand from smem_desc for this mainloop."); @@ -163,17 +177,6 @@ struct CollectiveMma< static_assert(cute::is_same_v || cute::is_same_v, "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - // TMA converts f32 input to tf32 when copying from GMEM to SMEM - // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. - static constexpr bool ConvertF32toTF32A = cute::is_same_v; - static constexpr bool ConvertF32toTF32B = cute::is_same_v; - using ConvertedElementA = cute::conditional_t>>; - using ConvertedElementB = cute::conditional_t>>; - using InternalElementA = cute::conditional_t; - using InternalElementB = cute::conditional_t; - using InternalStrideA = cute::conditional_t; - using InternalStrideB = cute::conditional_t; - using GmmaSmemLayoutAtomB = decltype(transform::collective::detail::gmma_smem_transpose_or_passthrough< TransposeB, InternalSmemLayoutAtomB, InternalElementB>()); @@ -181,23 +184,21 @@ struct CollectiveMma< using GmmaSmemLayoutB = decltype(tile_to_shape( GmmaSmemLayoutAtomB{}, make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), - Step<_2,_1,_3>{})); + conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); static_assert(!SwapAB || !TransposeB, "Cannot SwapAB and TransposeB at the same time."); - static_assert(TransposeB || (cute::is_same_v), + static_assert(TransposeB xor (cute::is_same_v), "Should be same layout if not TransposeB."); static_assert(!TransposeB || size<1>(SmemLayoutB{}) * sizeof(InternalElementB) == 128, "SmemLayoutB K must be 128bytes to be transposed."); static_assert(!transform::collective::detail::use_universal_transposition(), "Warp specialized ARF kernels have not supported universal B transposition yet."); - static_assert(!TransposeB || !cute::is_same_v, - "Transpose RS kernel requires kernel schedule schmem is not KernelTmaWarpSpecializedCooperative."); struct SharedStorage { - struct TensorStorage : cute::aligned_struct<128> { - cute::array_aligned> smem_A; - cute::array_aligned> smem_B; + struct TensorStorage : cute::aligned_struct<256> { + cute::array_aligned, 256> smem_A; + cute::array_aligned, 256> smem_B; } tensors; using PipelineStorage = typename MainloopPipeline::SharedStorage; @@ -212,6 +213,7 @@ struct CollectiveMma< StrideA dA; ElementB const* ptr_B; StrideB dB; + uint32_t mma_promotion_interval = 4; }; // Device side kernel params @@ -243,12 +245,9 @@ struct CollectiveMma< to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { (void) workspace; - // Optionally append _1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(problem_shape, Int<1>{}); - auto M = get<0>(problem_shape_MNKL); - auto N = get<1>(problem_shape_MNKL); - auto K = get<2>(problem_shape_MNKL); - auto L = get<3>(problem_shape_MNKL); + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; if constexpr (SwapAB) { M = get<1>(problem_shape_MNKL); @@ -293,6 +292,27 @@ struct CollectiveMma< }; } + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; static constexpr int K_PIPE_MMAS = DispatchPolicy::PipelineAsyncMmaStages; static_assert(K_PIPE_MMAS == 0, "no MMA stage should be asynchronous for this mainloop for now."); @@ -323,11 +343,12 @@ struct CollectiveMma< TensorB const& gB, TMA_LOAD_B& tma_load_b, KTileIterator k_tile_iter, int k_tile_count, int thread_idx, + uint32_t block_rank_in_cluster, TensorStorage& shared_tensors) { using namespace cute; - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int warp_idx_in_warp_group = warp_idx % 4; int lane_predicate = cute::elect_one_sync(); @@ -338,8 +359,10 @@ struct CollectiveMma< // // Prepare the TMA loads for A and B // + + constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - dim3 cluster_local_block_id = cute::block_id_in_cluster(); auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); @@ -397,7 +420,7 @@ struct CollectiveMma< CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int warp_idx_in_warp_group = warp_idx % 4; int lane_predicate = cute::elect_one_sync(); @@ -439,12 +462,13 @@ struct CollectiveMma< "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); // Obtain warp index - int warp_idx = canonical_warp_idx(); - int warp_idx_in_warp_group = warp_idx % 4; - int warp_group_thread_idx = thread_idx % 128; + int warp_idx = canonical_warp_idx_sync(); + [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_M,BLK_K,PIPE) // If TransposeB, GMMA will read from transposed B layout SMEM Tensor gmma_sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), GmmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE) @@ -462,13 +486,16 @@ struct CollectiveMma< Tensor tCsB = thread_mma.partition_B(gmma_sB); // (MMA,MMA_N,MMA_K,PIPE) Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) - // - // Copy Atom A retiling - // + // + // Copy Atom A retiling + // + - auto smem_tiled_copy_A = make_tiled_copy_A(InternalSmemCopyAtomA{}, tiled_mma); - auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); - Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + auto smem_tiled_copy_A = make_tiled_copy_A(InternalSmemCopyAtomA{}, tiled_mma); + + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K @@ -488,81 +515,156 @@ struct CollectiveMma< // We release buffers to producer warps(dma load) with some mmas in flight PipelineState smem_pipe_release = smem_pipe_read; - // Prologue GMMAs - int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + TransposeOperandB transpose = cutlass::transform::collective::detail::make_transpose_operand_b( + warp_idx, warp_group_thread_idx, tiled_mma, SmemLayoutB{}, + InternalSmemLayoutAtomB{}, InternalElementB{}, + cute::bool_constant{}); + warpgroup_fence_operand(accum); - CUTLASS_PRAGMA_UNROLL - for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + // first k tile + { pipeline.consumer_wait(smem_pipe_read); + int read_stage = smem_pipe_read.index(); + + ++smem_pipe_read; + + bool skip_wait = (pipeline.consumer_try_wait(smem_pipe_read) == BarrierStatus::WaitDone); + // copy smem->rmem for A operand - copy(smem_tiled_copy_A, tCsA(_,_,_,smem_pipe_read.index()), tCrA_copy_view); + copy(smem_tiled_copy_A, tCsA(_,_,0,read_stage), tCrA_copy_view(_,_,0)); // transpose B operand in SMEM - if constexpr (TransposeB) { - transform::collective::detail::transpose_b_operand( - sB, gmma_sB, smem_pipe_read, warp_idx_in_warp_group, warp_group_thread_idx, - tiled_mma, SmemLayoutB{}, InternalSmemLayoutAtomB{}, InternalElementB{}); - } // if TransposeB + transpose(sB, gmma_sB, read_stage, 0); - int read_stage = smem_pipe_read.index(); - warpgroup_arrive(); // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M,K) x (V,N,K) => (V,M,N) + for (int k_block = 0; k_block < size<2>(tCrA) - 1; ++k_block) { + copy(smem_tiled_copy_A, tCsA(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); + if (k_block == 0) { + transpose(sB, gmma_sB, read_stage, 1); + transpose.synchronize(); + } + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); } - warpgroup_commit_batch(); + warpgroup_wait<2>(); + + + if (k_tile_count - 1 > 0) { + if (!skip_wait) { + pipeline.consumer_wait(smem_pipe_read); + } + copy(smem_tiled_copy_A, tCsA(_,_,0,smem_pipe_read.index()), tCrA_copy_view(_,_,0)); + transpose(sB, gmma_sB, smem_pipe_read.index(), 0); + } - ++smem_pipe_read; + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,size<2>(tCrA) - 1), tCrB(_,_,size<2>(tCrA) - 1,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait<2>(); } + warpgroup_fence_operand(accum); // Mainloop GMMAs - k_tile_count -= prologue_mma_count; + --k_tile_count; CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > 0; --k_tile_count) { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) - pipeline.consumer_wait(smem_pipe_read); + for ( ; k_tile_count > 1; --k_tile_count) { // // Compute on k_tile // - // copy smem->rmem for A operand - copy(smem_tiled_copy_A, tCsA(_,_,_,smem_pipe_read.index()), tCrA_copy_view); - // transpose B operand in SMEM - if constexpr (TransposeB) { - transform::collective::detail::transpose_b_operand( - sB, gmma_sB, smem_pipe_read, warp_idx_in_warp_group, warp_group_thread_idx, - tiled_mma, SmemLayoutB{}, InternalSmemLayoutAtomB{}, InternalElementB{}); - } // if TransposeB int read_stage = smem_pipe_read.index(); + + ++smem_pipe_read; + bool skip_wait = (pipeline.consumer_try_wait(smem_pipe_read) == BarrierStatus::WaitDone); + warpgroup_fence_operand(accum); - warpgroup_arrive(); // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { - // (V,M,K) x (V,N,K) => (V,M,N) + if (k_block == size<2>(tCrA) - 1) { + if (!skip_wait) { + pipeline.consumer_wait(smem_pipe_read); + } + copy(smem_tiled_copy_A, tCsA(_,_,0,smem_pipe_read.index()), tCrA_copy_view(_,_,0)); + // transpose B operand in SMEM + transpose(sB, gmma_sB, smem_pipe_read.index(), 0); + } else { + copy(smem_tiled_copy_A, tCsA(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); + // transpose B operand in SMEM + if (k_block < 2) { + transpose.synchronize(k_block); // make transpose of k_block available + } + if (k_block == 0) { + transpose(sB, gmma_sB, read_stage, 1); + } + } + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait<2>(); + if (k_block == 1) { + // release prior barrier + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } } - warpgroup_commit_batch(); - - /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed - warpgroup_wait(); warpgroup_fence_operand(accum); - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + } - // Advance smem_pipe_read and smem_pipe_release - ++smem_pipe_read; - ++smem_pipe_release; + warpgroup_fence_operand(accum); + + if (k_tile_count > 0) { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) - 1; ++k_block) { + copy(smem_tiled_copy_A, tCsA(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); + if (k_block < 2) { + transpose.synchronize(k_block); // make k_block transpose available + } + if (k_block == 0) { + transpose(sB, gmma_sB, read_stage, 1); + } + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait<2>(); + if (k_block == 1) { + // release prior barrier + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,size<2>(tCrA) - 1), tCrB(_,_,size<2>(tCrA) - 1,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait<2>(); + warpgroup_fence_operand(accum); } warpgroup_fence_operand(accum); @@ -572,7 +674,7 @@ struct CollectiveMma< CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { // Prologue GMMAs - int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + int prologue_mma_count = 1; k_tile_count -= prologue_mma_count; smem_pipe_release.advance(k_tile_count); diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp index cf0a050dcd..932765ea7a 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp @@ -41,6 +41,7 @@ #include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" #include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -120,16 +121,15 @@ struct CollectiveMma< static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - // Tile along K mode first before tiling over MN. PIPE mode last as usual. - // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + // Tile along modes in a way that maximizes the TMA box size. using SmemLayoutA = decltype(tile_to_shape( SmemLayoutAtomA{}, make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), - Step<_2,_1,_3>{})); + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); using SmemLayoutB = decltype(tile_to_shape( SmemLayoutAtomB{}, make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), - Step<_2,_1,_3>{})); + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); static_assert(cute::is_base_of::value && @@ -162,6 +162,7 @@ struct CollectiveMma< StrideA dA; ElementB const* ptr_B; StrideB dB; + uint32_t mma_promotion_interval = 4; }; // Device side kernel params @@ -193,12 +194,9 @@ struct CollectiveMma< to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { (void) workspace; - // Optionally append _1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(problem_shape, Int<1>{}); - auto M = get<0>(problem_shape_MNKL); - auto N = get<1>(problem_shape_MNKL); - auto K = get<2>(problem_shape_MNKL); - auto L = get<3>(problem_shape_MNKL); + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; auto ptr_A = reinterpret_cast(args.ptr_A); auto ptr_B = reinterpret_cast(args.ptr_B); @@ -223,6 +221,27 @@ struct CollectiveMma< }; } + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE static void prefetch_tma_descriptors(Params const& mainloop_params) @@ -245,6 +264,7 @@ struct CollectiveMma< FrgTensorC& accum, KTileIterator k_tile_iter, int k_tile_count, int thread_idx, + uint32_t block_rank_in_cluster, char* shared_memory, Params const& mainloop_params) { @@ -267,7 +287,10 @@ struct CollectiveMma< // // Prepare the TMA loads for A and B // - dim3 cluster_local_block_id = cute::block_id_in_cluster(); + + constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); @@ -303,7 +326,7 @@ struct CollectiveMma< // Obtain warp index - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; PipelineParams params; diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp index 39c1a17aa4..c7dee7b1d4 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp @@ -41,6 +41,7 @@ #include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" #include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -121,16 +122,15 @@ struct CollectiveMma< static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - // Tile along K mode first before tiling over MN. PIPE mode last as usual. - // This maximizes TMA boxes due to better smem-K vectorization, reducing total issued TMAs. + // Tile along modes in a way that maximizes the TMA box size. using SmemLayoutA = decltype(tile_to_shape( SmemLayoutAtomA{}, make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), - Step<_2,_1,_3>{})); + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); using SmemLayoutB = decltype(tile_to_shape( SmemLayoutAtomB{}, make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), - Step<_2,_1,_3>{})); + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); static_assert(cute::is_base_of::value && @@ -167,6 +167,7 @@ struct CollectiveMma< StrideA dA; ElementB const* ptr_B; StrideB dB; + uint32_t mma_promotion_interval = 4; }; // Device side kernel params @@ -175,14 +176,14 @@ struct CollectiveMma< using TMA_A = decltype(make_tma_copy( GmemTiledCopyA{}, make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), - SmemLayoutA{}(_,_,0), + SmemLayoutA{}(_,_,cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any // Assumption: StrideB is congruent with Problem_NK using TMA_B = decltype(make_tma_copy( GmemTiledCopyB{}, make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), - SmemLayoutB{}(_,_,0), + SmemLayoutB{}(_,_,cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any TMA_A tma_load_a; @@ -198,12 +199,9 @@ struct CollectiveMma< to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { (void) workspace; - // Optionally append _1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(problem_shape, Int<1>{}); - auto M = get<0>(problem_shape_MNKL); - auto N = get<1>(problem_shape_MNKL); - auto K = get<2>(problem_shape_MNKL); - auto L = get<3>(problem_shape_MNKL); + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; auto ptr_A = reinterpret_cast(args.ptr_A); auto ptr_B = reinterpret_cast(args.ptr_B); @@ -228,6 +226,27 @@ struct CollectiveMma< }; } + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; static constexpr int K_PIPE_MMAS = 1; static constexpr uint32_t TmaTransactionBytes = @@ -257,11 +276,12 @@ struct CollectiveMma< TensorB const& gB, TMA_LOAD_B& tma_load_b, KTileIterator k_tile_iter, int k_tile_count, int thread_idx, + uint32_t block_rank_in_cluster, TensorStorage& shared_tensors) { using namespace cute; - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int warp_idx_in_warp_group = warp_idx % 4; int lane_predicate = cute::elect_one_sync(); @@ -273,7 +293,9 @@ struct CollectiveMma< // Prepare the TMA loads for A and B // - dim3 cluster_local_block_id = cute::block_id_in_cluster(); + constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); @@ -334,7 +356,7 @@ struct CollectiveMma< MainloopPipeline pipeline, PipelineState smem_pipe_write) { - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int warp_idx_in_warp_group = warp_idx % 4; int lane_predicate = cute::elect_one_sync(); @@ -417,7 +439,8 @@ struct CollectiveMma< for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) { // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) - pipeline.consumer_wait(smem_pipe_read); + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); int read_stage = smem_pipe_read.index(); warpgroup_arrive(); @@ -442,7 +465,8 @@ struct CollectiveMma< for ( ; k_tile_count > 0; --k_tile_count) { // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) - pipeline.consumer_wait(smem_pipe_read); + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); // // Compute on k_tile diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp new file mode 100644 index 0000000000..0e16027139 --- /dev/null +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp @@ -0,0 +1,537 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/gemm/collective/fp8_accumulation.hpp" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaWarpSpecializedFP8, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedFP8; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaAsync< + DispatchPolicy::Stages, + typename DispatchPolicy::ClusterShape>; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,0), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,0), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t mma_promotion_interval = 4; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return { + tma_load_a, + tma_load_b, + args.mma_promotion_interval + }; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + /* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA instructions. */ + implementable = implementable && (args.mma_promotion_interval % 4 == 0); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytes = + (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof(ElementA)))+ + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof(ElementB))); + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) + { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TMA_LOAD_A, + class TensorB, class TMA_LOAD_B, + class KTileIterator + > + CUTLASS_DEVICE void + load( + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + TensorA const& gA, TMA_LOAD_A& tma_load_a, + TensorB const& gB, TMA_LOAD_B& tma_load_b, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) + { + + using namespace cute; + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % 4; + int lane_predicate = cute::elect_one_sync(); + + if (warp_idx_in_warp_group == 0 and lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail( + MainloopPipeline pipeline, + PipelineState smem_pipe_write) + { + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % 4; + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (warp_idx_in_warp_group == 0 and lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) + { + using namespace cute; + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + GmmaFP8Accumulation accumulation(accum, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + warpgroup_fence_operand(accumulation()); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + accumulation.promote_if_needed(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accumulation()); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + if (accumulation.prepare_if_needed()) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accumulation()); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accumulation()); + + accumulation.promote_if_needed(); + + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + accumulation.promote_residue_if_needed(); + + warpgroup_fence_operand(accumulation()); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/default_gemm_configuration.h b/include/cutlass/gemm/device/default_gemm_configuration.h index f4ce58514b..7b95d01963 100644 --- a/include/cutlass/gemm/device/default_gemm_configuration.h +++ b/include/cutlass/gemm/device/default_gemm_configuration.h @@ -496,7 +496,7 @@ struct DefaultGemmConfiguration::value, ElementAccumulator, + ElementC, 1, ElementAccumulator, ElementAccumulator>; using Operator = arch::OpMultiplyAdd; @@ -777,7 +777,7 @@ struct DefaultGemmConfiguration::value, ElementAccumulator, + ElementC, 1, ElementAccumulator, ElementAccumulator>; using Operator = arch::OpMultiplyAdd; diff --git a/include/cutlass/gemm/device/gemm_splitk_parallel.h b/include/cutlass/gemm/device/gemm_splitk_parallel.h index 55db9552dd..aa45f4b447 100644 --- a/include/cutlass/gemm/device/gemm_splitk_parallel.h +++ b/include/cutlass/gemm/device/gemm_splitk_parallel.h @@ -250,9 +250,6 @@ class GemmSplitKParallel { /// Determines whether the GEMM can execute the given problem. static Status can_implement(Arguments const &args) { - - // TODO - return Status::kSuccess; } diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 04ead3c39c..896bff187d 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -38,10 +38,14 @@ // common #include "cutlass/cutlass.h" -#include "cutlass/trace.h" -#include "cutlass/cluster_launch.hpp" #include "cutlass/device_kernel.h" #include "cutlass/gemm/gemm.h" +#include "cutlass/detail/layout.hpp" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) // 2.x #include "cutlass/gemm/device/gemm_universal_base.h" @@ -107,16 +111,14 @@ class GemmUniversalAdapter< // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0 using MathOperator = cutlass::arch::OpMultiplyAdd; - // If our TiledMMA's instruction thread layout size is larger than 1, we know its a tensorop! + // All tensorop operations have atom shape's M >= 8 using OperatorClass = cute::conditional_t< - (cute::size(typename GemmKernel::TiledMma::AtomThrID{}) > 1), - cutlass::arch::OpClassTensorOp, cutlass::arch::OpClassSimt>; + cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}) >= 8, + cutlass::arch::OpClassTensorOp, cutlass::arch::OpClassSimt>; using ArchTag = typename GemmKernel::ArchTag; // NOTE: Assume identity swizzle for now - static_assert(cute::is_void_v, - "CUTLASS 3.x kernel types do not support grid swizzle functors yet."); using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape @@ -155,13 +157,13 @@ class GemmUniversalAdapter< static int constexpr kStages = CollectiveMainloop::DispatchPolicy::Stages; // Inspect TiledCopy for A and B to compute the alignment size - static int constexpr kAlignmentA = gemm::detail::get_alignment_count_from_gmem_tiled_copy< + static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< typename CollectiveMainloop::GmemTiledCopyA, ElementA>(); - static int constexpr kAlignmentB = gemm::detail::get_alignment_count_from_gmem_tiled_copy< + static int constexpr kAlignmentB = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< typename CollectiveMainloop::GmemTiledCopyB, ElementB>(); - static int constexpr kAlignmentC = gemm::detail::get_alignment_count_from_gmem_tiled_copy< + static int constexpr kAlignmentC = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< typename CollectiveEpilogue::GmemTiledCopyC, ElementC>(); - static int constexpr kAlignmentD = gemm::detail::get_alignment_count_from_gmem_tiled_copy< + static int constexpr kAlignmentD = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< typename CollectiveEpilogue::GmemTiledCopyD, ElementD>(); using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; @@ -182,6 +184,11 @@ class GemmUniversalAdapter< public: + /// Access the Params structure + Params const& params() const { + return params_; + } + /// Determines whether the GEMM can execute the given problem. static Status can_implement(Arguments const& args) { @@ -268,24 +275,10 @@ class GemmUniversalAdapter< CUTLASS_TRACE_HOST("GemmUniversal::initialize() - workspace " << workspace << ", stream: " << (stream ? "non-null" : "null")); - size_t workspace_bytes = GemmKernel::get_workspace_size(args); - CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - - if (workspace_bytes) { - if (!workspace) { - CUTLASS_TRACE_HOST(" error: device workspace must not be null"); - return Status::kErrorWorkspaceNull; - } - - if (args.mode == GemmUniversalMode::kGemm) { - CUTLASS_TRACE_HOST(" clearing device workspace"); - cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); - if (cudaSuccess != result) { - result = cudaGetLastError(); // to clear the error bit - CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - } + // Initialize the workspace + Status status = GemmKernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; } // Initialize the Params structure diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h index cdedd3fe4d..204835fe8f 100644 --- a/include/cutlass/gemm/device/gemm_universal_base.h +++ b/include/cutlass/gemm/device/gemm_universal_base.h @@ -103,16 +103,16 @@ class GemmUniversalBase { // // Device ordinal - thread_local static int device_ordinal_; + CUTLASS_THREAD_LOCAL static int device_ordinal_; /// Device SM count - thread_local static int device_sms_; + CUTLASS_THREAD_LOCAL static int device_sms_; /// Kernel SM occupancy (in thread blocks) - thread_local static int sm_occupancy_; + CUTLASS_THREAD_LOCAL static int sm_occupancy_; /// Kernel dynamic shared memory allocation requirement - thread_local static int smem_size_; + CUTLASS_THREAD_LOCAL static int smem_size_; /// Initialize static thread-local members for the thread's current device, /// if necessary. @@ -323,7 +323,11 @@ class GemmUniversalBase { } // Assign and prepare workspace memory - return params_.init_workspace(workspace, stream); + if (args.mode == GemmUniversalMode::kGemm) { + return params_.init_workspace(workspace, stream); + } + + return Status::kSuccess; } @@ -394,19 +398,19 @@ class GemmUniversalBase { /// Device ordinal template -thread_local int GemmUniversalBase::device_ordinal_ = -1; +CUTLASS_THREAD_LOCAL int GemmUniversalBase::device_ordinal_ = -1; /// Device SM count template -thread_local int GemmUniversalBase::device_sms_ = -1; +CUTLASS_THREAD_LOCAL int GemmUniversalBase::device_sms_ = -1; /// Kernel SM occupancy (in thread blocks) template -thread_local int GemmUniversalBase::sm_occupancy_ = -1; +CUTLASS_THREAD_LOCAL int GemmUniversalBase::sm_occupancy_ = -1; /// Kernel dynamic shared memory allocation requirement template -thread_local int GemmUniversalBase::smem_size_ = -1; +CUTLASS_THREAD_LOCAL int GemmUniversalBase::smem_size_ = -1; diff --git a/include/cutlass/gemm/device/gemv.h b/include/cutlass/gemm/device/gemv.h index 2cd3014e04..b2131d2f76 100644 --- a/include/cutlass/gemm/device/gemv.h +++ b/include/cutlass/gemm/device/gemv.h @@ -77,12 +77,6 @@ class Gemv { static int const kThreadCount = GemvKernel::kThreadCount; static int const kThreadsPerRow = GemvKernel::kThreadsPerRow; - static int const kStages = GemvKernel::kStages; - - static int const kAlignmentA = GemvKernel::kAlignmentA; - static int const kAlignmentB = GemvKernel::kAlignmentB; - static int const kAlignmentC = GemvKernel::kAlignmentC; - using Arguments = typename GemvKernel::Arguments; using Params = typename GemvKernel::Params; diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index aee918ee26..f122fe0384 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -31,10 +31,10 @@ #pragma once #include "cutlass/arch/arch.h" +#include "cutlass/gemm/gemm.h" #include "cute/layout.hpp" #include "cute/numeric/integral_constant.hpp" - ////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm { @@ -51,6 +51,11 @@ struct KernelTmaWarpSpecialized { }; struct KernelTmaWarpSpecializedPingpong { }; struct KernelTmaWarpSpecializedCooperative { }; +// FP8 related policies (including Fast Accumulation) +struct KernelTmaWarpSpecializedFP8FastAccum : KernelTmaWarpSpecialized { }; +struct KernelTmaWarpSpecializedPingpongFP8FastAccum : KernelTmaWarpSpecializedPingpong { }; +struct KernelTmaWarpSpecializedCooperativeFP8FastAccum: KernelTmaWarpSpecializedCooperative { }; + // Policies for dispatch of epilogue struct EpilogueDefault { }; struct EpilogueTransposed { }; @@ -165,6 +170,22 @@ struct MainloopSm90TmaGmmaRmemAWarpSpecialized { "KernelSchedule must be one of the warp specialized policies"); }; +// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule +// For FP8 kernels +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1>, + class KernelSchedule = KernelTmaWarpSpecialized +> +struct MainloopSm90TmaGmmaWarpSpecializedFP8 + : MainloopSm90TmaGmmaWarpSpecialized { + static_assert( + cute::is_same_v || + cute::is_same_v || + cute::is_same_v, + "KernelSchedule must be one of the warp specialized policies"); +}; + ////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::gemm diff --git a/include/cutlass/gemm/gemm.h b/include/cutlass/gemm/gemm.h index cbb18d709f..ec90721376 100644 --- a/include/cutlass/gemm/gemm.h +++ b/include/cutlass/gemm/gemm.h @@ -35,9 +35,11 @@ #include "cutlass/cutlass.h" #include "cutlass/coord.h" +#include "cutlass/gemm_coord.h" #include "cutlass/layout/matrix.h" #include "cute/layout.hpp" -#include "cute/arch/copy_sm90_tma.hpp" +#include "cutlass/detail/layout.hpp" + namespace cutlass { namespace gemm { @@ -53,356 +55,6 @@ enum class Operand { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Shape of a matrix multiply-add operation -template < - /// Rows of matrix product - int M = 1, - /// Columns of matrix product - int N = 1, - /// Inner dimension of matrix product - int K = 1 -> -struct GemmShape { - static int const kM = M; - static int const kN = N; - static int const kK = K; - - static int const kMN = M * N; - static int const kMK = M * K; - static int const kKN = N * K; - static int const kMNK = M * N * K; - - static int const kCount = kMNK; - - // - // Static member functions - // - - /// Returns a Coord object - CUTLASS_HOST_DEVICE - static Coord<3> toCoord() { - return make_Coord(kM, kN, kK); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Type alias of the transpose of a GemmShape -template < - /// concept: GemmShape - typename Shape -> -using GemmShapeTranspose = GemmShape; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// GemmCoord is a structure derived from Coord<3> that specifies a location within the -/// coordinate space of a GEMM problem. -struct GemmCoord : public Coord<3, int> { - - /// Integer-valued index - typedef int Index; - - /// Base type is a Coord of rank=3 - typedef Coord<3, Index> Base; - - /// GEMM M dimension - rows of the output C matrix - static int const kM = 0; - - /// GEMM N dimension - columns of the output C matrix - static int const kN = 1; - - /// GEMM K dimension - inner dimension of the GEMM problem - static int const kK = 2; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - GemmCoord() { } - - /// Constructs from Coord<3> and a batch - CUTLASS_HOST_DEVICE - GemmCoord(Coord<3, Index> const &coord): Base(make_Coord(coord[0], coord[1], coord[2])) { } - - /// Helper to construct from a K, N, M, batch variables - CUTLASS_HOST_DEVICE - GemmCoord(Index m, Index n, Index k): Base(make_Coord(m, n, k)) { } - - /// Returns the GEMM M coordinate - CUTLASS_HOST_DEVICE - Index const & m() const { return this->at(kM); } - - /// Returns reference to the GEMM M coordinate - CUTLASS_HOST_DEVICE - Index & m() { return this->at(kM); } - - /// Returns the GEMM N coordinate - CUTLASS_HOST_DEVICE - Index const & n() const { return this->at(kN); } - - /// Returns reference to the GEMM N coordinate - CUTLASS_HOST_DEVICE - Index & n() { return this->at(kN); } - - /// Returns the GEMM K coordinate - CUTLASS_HOST_DEVICE - Index const & k() const { return this->at(kK); } - - /// Returns reference to the GEMM K coordinate - CUTLASS_HOST_DEVICE - Index & k() { return this->at(kK); } - - /// Obtains a Coord<3> from GemmCoord - CUTLASS_HOST_DEVICE - Coord<3> mnk() const { - return make_Coord(m(), n(), k()); - } - - /// Obtains a Coord<3> from GemmCoord - CUTLASS_HOST_DEVICE - Coord<3> knm() const { - return make_Coord(k(), n(), m()); - } - - /// Obtains a Coord<2> from GemmCoord - CUTLASS_HOST_DEVICE - Coord<2> nm() const { - return make_Coord(n(), m()); - } - - /// Obtains a Coord<2> from GemmCoord - CUTLASS_HOST_DEVICE - Coord<2> mn() const { - return make_Coord(m(), n()); - } - - /// Obtains a Coord<2> from GemmCoord - CUTLASS_HOST_DEVICE - Coord<2> mk() const { - return make_Coord(m(), k()); - } - - /// Obtains a Coord<2> from GemmCoord - CUTLASS_HOST_DEVICE - Coord<2> km() const { - return make_Coord(k(), m()); - } - - /// Obtains a Coord<2> from GemmCoord - CUTLASS_HOST_DEVICE - Coord<2> nk() const { - return make_Coord(n(), k()); - } - - /// Obtains a Coord<2> from GemmCoord - CUTLASS_HOST_DEVICE - Coord<2> kn() const { - return make_Coord(k(), n()); - } - - // - // Coord operators - // - - /// Element-wise addition - CUTLASS_HOST_DEVICE - GemmCoord operator+(Base const& b) const { - return GemmCoord(Base::operator+(b)); - } - - /// Element-wise subtraction - CUTLASS_HOST_DEVICE - GemmCoord operator-(Base const& b) const { - return GemmCoord(Base::operator-(b)); - } - - /// Element-wise multiplication - CUTLASS_HOST_DEVICE - GemmCoord operator*(Base const& b) const { - return GemmCoord(Base::operator*(b)); - } - - /// Element-wise division - CUTLASS_HOST_DEVICE - GemmCoord operator/(Base const& b) const { - return GemmCoord(Base::operator/(b)); - } - - /// In-place addition - CUTLASS_HOST_DEVICE - GemmCoord& operator+=(Base const& b) { - Base::operator+=(b); - return *this; - } - - /// In-place subtraction - CUTLASS_HOST_DEVICE - GemmCoord& operator-=(Base const& b) { - Base::operator-=(b); - return *this; - } - - /// In-place multiplication - CUTLASS_HOST_DEVICE - GemmCoord& operator*=(Base const& b) { - Base::operator*=(b); - return *this; - } - - /// In-place division - CUTLASS_HOST_DEVICE - GemmCoord& operator/=(Base const& b) { - Base::operator/=(b); - return *this; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/// BatchedGemmCoord is a structure derived from Coord<4> that specifies a location within the -/// coordinate space of a batched GEMM problem. -struct BatchedGemmCoord : public Coord<4, int> { - - /// Integer-valued index - typedef int Index; - - /// Base type is a Coord of rank=4 - typedef Coord<4, Index> Base; - - /// GEMM M dimension - rows of the output C matrix - static int const kM = 0; - - /// GEMM N dimension - columns of the output C matrix - static int const kN = 1; - - /// GEMM K dimension - inner dimension of the GEMM problem - static int const kK = 2; - - /// GEMM Batch dimension - inner dimension of the GEMM problem - static int const kBatch = 3; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - BatchedGemmCoord() { } - - /// Constructs from Coord<4> - CUTLASS_HOST_DEVICE - BatchedGemmCoord(Base const &coord): Base(coord) { } - - /// Helper to construct from a K, N, M, and batch variables - CUTLASS_HOST_DEVICE - BatchedGemmCoord(Index m, Index n, Index k, Index b): Base(make_Coord(m, n, k, b)) { } - - /// Returns the GEMM M coordinate - CUTLASS_HOST_DEVICE - Index const & m() const { return this->at(kM); } - - /// Returns reference to the GEMM M coordinate - CUTLASS_HOST_DEVICE - Index & m() { return this->at(kM); } - - /// Returns the GEMM N coordinate - CUTLASS_HOST_DEVICE - Index const & n() const { return this->at(kN); } - - /// Returns reference to the GEMM N coordinate - CUTLASS_HOST_DEVICE - Index & n() { return this->at(kN); } - - /// Returns the GEMM K coordinate - CUTLASS_HOST_DEVICE - Index const & k() const { return this->at(kK); } - - /// Returns reference to the GEMM K coordinate - CUTLASS_HOST_DEVICE - Index & k() { return this->at(kK); } - - /// Returns the GEMM batch coordinate - CUTLASS_HOST_DEVICE - Index const & batch() const { return this->at(kBatch); } - - /// Returns reference to the GEMM batch coordinate - CUTLASS_HOST_DEVICE - Index & batch() { return this->at(kBatch); } - - /// Obtains a GemmCoord from BatchedGemmCoord - CUTLASS_HOST_DEVICE - GemmCoord mnk() const { - return GemmCoord(m(), n(), k()); - } - - /// Obtains a Coord<4> from BatchedGemmCoord - CUTLASS_HOST_DEVICE - Coord<4> mnkb() const { - return make_Coord(m(), n(), k(), batch()); - } - - // - // Coord operators - // - - /// Element-wise addition - CUTLASS_HOST_DEVICE - BatchedGemmCoord operator+(Base const& b) const { - return BatchedGemmCoord(Base::operator+(b)); - } - - /// Element-wise subtraction - CUTLASS_HOST_DEVICE - BatchedGemmCoord operator-(Base const& b) const { - return BatchedGemmCoord(Base::operator-(b)); - } - - /// Element-wise multiplication - CUTLASS_HOST_DEVICE - BatchedGemmCoord operator*(Base const& b) const { - return BatchedGemmCoord(Base::operator*(b)); - } - - /// Element-wise division - CUTLASS_HOST_DEVICE - BatchedGemmCoord operator/(Base const& b) const { - return BatchedGemmCoord(Base::operator/(b)); - } - - /// In-place addition - CUTLASS_HOST_DEVICE - BatchedGemmCoord& operator+=(Base const& b) { - Base::operator+=(b); - return *this; - } - - /// In-place subtraction - CUTLASS_HOST_DEVICE - BatchedGemmCoord& operator-=(Base const& b) { - Base::operator-=(b); - return *this; - } - - /// In-place multiplication - CUTLASS_HOST_DEVICE - BatchedGemmCoord& operator*=(Base const& b) { - Base::operator*=(b); - return *this; - } - - /// In-place division - CUTLASS_HOST_DEVICE - BatchedGemmCoord& operator/=(Base const& b) { - Base::operator/=(b); - return *this; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - enum class GemmUniversalMode { kGemm, kGemmSplitKParallel, @@ -422,160 +74,41 @@ enum class SharedMemoryClearOption { //////////////////////////////////////////////////////////////////////////////////////////////////// -// For each cutlass::layout, provides its corresponding cute stride types, 64b by default - -template -struct TagToStrideA { - using type = L; -}; - -// Maps to modes [M, K, L] -template <> -struct TagToStrideA { - using type = cute::Stride, int64_t>; - using tag = layout::RowMajor; -}; - -// Maps to modes [M, K, L] -template <> -struct TagToStrideA { - using type = cute::Stride, int64_t, int64_t>; - using tag = layout::ColumnMajor; -}; - -template -struct TagToStrideB { - using type = L; -}; - -// Maps to modes [N, K, L] -template <> -struct TagToStrideB { - using type = cute::Stride, int64_t, int64_t>; - using tag = layout::RowMajor; -}; - -// Maps to modes [N, K, L] -template <> -struct TagToStrideB { - using type = cute::Stride, int64_t>; - using tag = layout::ColumnMajor; -}; - - -// Maps to modes [N, N, L] -template -struct TagToStrideC : TagToStrideA { }; - -// Convenience aliases -template -using TagToStrideA_t = typename TagToStrideA::type; - -template -using TagToStrideB_t = typename TagToStrideB::type; - -template -using TagToStrideC_t = typename TagToStrideC::type; +using cutlass::detail::TagToStrideA; +using cutlass::detail::TagToStrideB; +using cutlass::detail::TagToStrideC; +using cutlass::detail::TagToStrideA_t; +using cutlass::detail::TagToStrideB_t; +using cutlass::detail::TagToStrideC_t; //////////////////////////////////////////////////////////////////////////////////////////////////// -// For 2.x compatibility APIs, provide stride->layout tag mappers namespace detail { -template -constexpr bool -is_mn_major() { - // Account for stride types with and without batch mode and batch modes with static zero stride - return cute::is_constant<1, decltype(cute::size<0,0>(Stride{}))>::value; -} - -// Note : This method can be used for deducing the Layout Tag of A, C, D Matrices -template -constexpr -auto -stride_to_layout_tag_A() { - if constexpr (is_mn_major()) { // M major - return layout::ColumnMajor{}; - } - else { // K major - return layout::RowMajor{}; - } - - CUTE_GCC_UNREACHABLE; -} - -template -constexpr -auto -stride_to_layout_tag_B() { - if constexpr (is_mn_major()) { // N major - return layout::RowMajor{}; - } - else { // K major - return layout::ColumnMajor{}; - } +using cutlass::detail::StrideToLayoutTagA; +using cutlass::detail::StrideToLayoutTagB; +using cutlass::detail::StrideToLayoutTagC; +using cutlass::detail::StrideToLayoutTagA_t; +using cutlass::detail::StrideToLayoutTagB_t; +using cutlass::detail::StrideToLayoutTagC_t; - CUTE_GCC_UNREACHABLE; +template +constexpr bool +is_major(Stride = {}) { + return ::cutlass::detail::is_major(); } -// Inspects a TiledCopy and returns its alignment in terms of element count -template -constexpr int -get_alignment_count_from_gmem_tiled_copy() { - if constexpr (cute::is_void_v) { - return 1; - } - - // Account for ElementC = void kernels - else if constexpr (cute::is_void_v) { - return 0; - } - - else { - // For TMA tiled copies, we know the alignment has to be 128 bits - if constexpr ( cute::is_base_of_v - || cute::is_base_of_v - || cute::is_base_of_v - ) { - return 128 / sizeof_bits::value; - } - else { - // For non-TMA tiled copies, TiledCopy holds the alignment count directly in its TiledShape_MN - return GmemTiledCopy::NumValSrc; - } - } +template +constexpr bool +is_mn_major() { + return is_major<0,Stride>(); } -// Utilities to map Stride back on to their corresponding layout tags -template -struct StrideToLayoutTagA { - using type = decltype(detail::stride_to_layout_tag_A()); -}; - -template -struct StrideToLayoutTagB { - using type = decltype(detail::stride_to_layout_tag_B()); -}; - -// Maps to modes [N, N, L] -template -struct StrideToLayoutTagC : StrideToLayoutTagA { }; - -// Convenience aliases -template -using StrideToLayoutTagA_t = typename StrideToLayoutTagA::type; - -template -using StrideToLayoutTagB_t = typename StrideToLayoutTagB::type; - -template -using StrideToLayoutTagC_t = typename StrideToLayoutTagC::type; - template constexpr bool is_k_major() { - return ! is_mn_major(); + return is_major<1,Stride>(); } template @@ -605,7 +138,7 @@ is_k_major_B() { /////////////////////////////////////////////////////////////////////////////// // The following two metafunctions are used to detect whether a `kernel::Gemm` or `kernel::GemmUniversal` -// is implementing the CUTLASS 3.x API or not, by checking if the problem shape type is aliased within or not. +// is implementing the CUTLASS 3.x API or not, by checking if the problem shape type is aliased within or not. template struct IsCutlass3GemmKernel : cute::false_type { }; diff --git a/include/cutlass/gemm/kernel/gemm.h b/include/cutlass/gemm/kernel/gemm.h index b5064ec7cf..1d2c024b7f 100644 --- a/include/cutlass/gemm/kernel/gemm.h +++ b/include/cutlass/gemm/kernel/gemm.h @@ -256,7 +256,7 @@ struct Gemm { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; // diff --git a/include/cutlass/gemm/kernel/gemm_array.h b/include/cutlass/gemm/kernel/gemm_array.h index 1862e206fd..464c355eea 100644 --- a/include/cutlass/gemm/kernel/gemm_array.h +++ b/include/cutlass/gemm/kernel/gemm_array.h @@ -193,7 +193,7 @@ struct GemmArray { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/gemm_batched.h b/include/cutlass/gemm/kernel/gemm_batched.h index 464aeef51d..fcb4ec2d5c 100644 --- a/include/cutlass/gemm/kernel/gemm_batched.h +++ b/include/cutlass/gemm/kernel/gemm_batched.h @@ -204,7 +204,7 @@ struct GemmBatched { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/gemm_grouped.h b/include/cutlass/gemm/kernel/gemm_grouped.h index 84dc4aeec9..310ff3b1d8 100644 --- a/include/cutlass/gemm/kernel/gemm_grouped.h +++ b/include/cutlass/gemm/kernel/gemm_grouped.h @@ -395,7 +395,7 @@ struct GemmGrouped { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h b/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h index c2daadbfcd..3fe842a040 100644 --- a/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h +++ b/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h @@ -304,14 +304,18 @@ struct GemmLayernormMainloopFusion { ThreadblockShape, ElementA, ElementB, - ElementC> + ElementC, + LayoutA, + LayoutB> { using ParamsBase = UniversalParamsBase< ThreadblockSwizzle, ThreadblockShape, ElementA, ElementB, - ElementC>; + ElementC, + LayoutA, + LayoutB>; // // Data members diff --git a/include/cutlass/gemm/kernel/gemm_pipelined.h b/include/cutlass/gemm/kernel/gemm_pipelined.h index df450d08c7..900e04428f 100644 --- a/include/cutlass/gemm/kernel/gemm_pipelined.h +++ b/include/cutlass/gemm/kernel/gemm_pipelined.h @@ -111,7 +111,7 @@ __global__ void GemmPipelined( tb_thread_id, tb_offset_B); - int warp_id = canonical_warp_idx(); + int warp_id = canonical_warp_idx_sync(); int lane_id = threadIdx.x % 32; // diff --git a/include/cutlass/gemm/kernel/gemm_planar_complex.h b/include/cutlass/gemm/kernel/gemm_planar_complex.h index 92243a97fe..6987d7e691 100644 --- a/include/cutlass/gemm/kernel/gemm_planar_complex.h +++ b/include/cutlass/gemm/kernel/gemm_planar_complex.h @@ -243,14 +243,18 @@ struct GemmPlanarComplex { ThreadblockShape, ElementA, ElementB, - ElementC> + ElementC, + LayoutA, + LayoutB> { using ParamsBase = UniversalParamsBase< ThreadblockSwizzle, ThreadblockShape, ElementA, ElementB, - ElementC>; + ElementC, + LayoutA, + LayoutB>; // // Data members @@ -533,7 +537,7 @@ struct GemmPlanarComplex { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/gemm_planar_complex_array.h b/include/cutlass/gemm/kernel/gemm_planar_complex_array.h index 713946f01b..6a3aa11c1d 100644 --- a/include/cutlass/gemm/kernel/gemm_planar_complex_array.h +++ b/include/cutlass/gemm/kernel/gemm_planar_complex_array.h @@ -231,14 +231,18 @@ struct GemmPlanarComplexArray { ThreadblockShape, ElementA, ElementB, - ElementC> + ElementC, + LayoutA, + LayoutB> { using ParamsBase = UniversalParamsBase< ThreadblockSwizzle, ThreadblockShape, ElementA, ElementB, - ElementC>; + ElementC, + LayoutA, + LayoutB>; // // Data members @@ -466,7 +470,7 @@ struct GemmPlanarComplexArray { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; // diff --git a/include/cutlass/gemm/kernel/gemm_universal.h b/include/cutlass/gemm/kernel/gemm_universal.h index 116d83db1c..8f146afbc2 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.h +++ b/include/cutlass/gemm/kernel/gemm_universal.h @@ -256,14 +256,18 @@ class GemmUniversal< ThreadblockShape, ElementA, ElementB, - ElementC> + ElementC, + LayoutA, + LayoutB> { using ParamsBase = UniversalParamsBase< ThreadblockSwizzle, ThreadblockShape, ElementA, ElementB, - ElementC>; + ElementC, + LayoutA, + LayoutB>; // // Data members @@ -542,7 +546,7 @@ class GemmUniversal< // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp index 7bee6bbdb6..4e046ddd3e 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.hpp +++ b/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -30,6 +30,8 @@ **************************************************************************************************/ #pragma once +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + //////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::kernel { @@ -54,7 +56,7 @@ template < class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l) class CollectiveMainloopOrEpilogue_, class CollectiveEpilogueOrThreadblockSwizzle_, - class GridSwizzle_ = void, + class TileScheduler_ = void, class Enable = void > class GemmUniversal; diff --git a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h index 5ef25d78a8..1c58b44ef4 100644 --- a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h +++ b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h @@ -228,14 +228,18 @@ struct GemmWithFusedEpilogue { ThreadblockShape, ElementA, ElementB, - ElementC> + ElementC, + LayoutA, + LayoutB> { using ParamsBase = UniversalParamsBase< ThreadblockSwizzle, ThreadblockShape, ElementA, ElementB, - ElementC>; + ElementC, + LayoutA, + LayoutB>; // // Data members @@ -955,14 +959,18 @@ struct GemmWithFusedEpilogue { ThreadblockShape, ElementA, ElementB, - ElementC> + ElementC, + LayoutA, + LayoutB> { using ParamsBase = UniversalParamsBase< ThreadblockSwizzle, ThreadblockShape, ElementA, ElementB, - ElementC>; + ElementC, + LayoutA, + LayoutB>; // // Data members @@ -1235,7 +1243,7 @@ struct GemmWithFusedEpilogue { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/gemm_with_k_reduction.h b/include/cutlass/gemm/kernel/gemm_with_k_reduction.h index c9195e3af4..863b0c4c29 100644 --- a/include/cutlass/gemm/kernel/gemm_with_k_reduction.h +++ b/include/cutlass/gemm/kernel/gemm_with_k_reduction.h @@ -198,14 +198,18 @@ struct GemmWithKReduction { ThreadblockShape, ElementA, ElementB, - ElementC> + ElementC, + LayoutA, + LayoutB> { using ParamsBase = UniversalParamsBase< ThreadblockSwizzle, ThreadblockShape, ElementA, ElementB, - ElementC>; + ElementC, + LayoutA, + LayoutB>; // // Data members @@ -510,7 +514,7 @@ struct GemmWithKReduction { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/gemv.h b/include/cutlass/gemm/kernel/gemv.h index 778ff46a8a..165b4474f4 100644 --- a/include/cutlass/gemm/kernel/gemv.h +++ b/include/cutlass/gemm/kernel/gemv.h @@ -115,8 +115,8 @@ struct Gemv < static ComplexTransform const kTransformB = ComplexTransform::kNone; // thread block shape (kThreadCount, 1, 1) - static int const kThreadCount = (kThreadCount_ == 0) ? 32 : kThreadCount_; - static int const kThreadsPerRow = kThreadsPerRow_; + static int const kThreadCount = (kThreadCount_ <= 0) ? 32 : kThreadCount_; + static int const kThreadsPerRow = (kThreadsPerRow_ <= 0) ? 1 : kThreadsPerRow_; static int const kStages = 1; @@ -396,8 +396,8 @@ struct Gemv < using FragmentCompute = Array; // thread block shape (kThreadsPerRow, kThreadCount / kThreadsPerRow, 1) - static int const kThreadCount = (kThreadCount_ == 0) ? 128 : kThreadCount_; - static int const kThreadsPerRow = (kThreadsPerRow_ == 0) ? + static int const kThreadCount = (kThreadCount_ <= 0) ? 128 : kThreadCount_; + static int const kThreadsPerRow = (kThreadsPerRow_ <= 0) ? std::min(static_cast(kThreadCount / (kElementsPerAccess * sizeof(ElementA))), 16) : kThreadsPerRow_; @@ -429,51 +429,51 @@ struct Gemv < Arguments(): batch_count(0) { } Arguments( - MatrixCoord problem_size, - int32_t batch_count, - typename EpilogueOutputOp::Params output_op, - TensorRefA ref_A, - void const *ptr_B, - void const *ptr_C, - void *ptr_D, - int64_t batch_stride_A, - int64_t batch_stride_B, - int64_t batch_stride_C, - int64_t batch_stride_D - ): - problem_size(problem_size), - batch_count(batch_count), - output_op(output_op), - ref_A(ref_A), - ptr_B(static_cast(ptr_B)), - ptr_C(static_cast(ptr_C)), - ptr_D(static_cast(ptr_D)), - batch_stride_A(batch_stride_A), - batch_stride_B(batch_stride_B), - batch_stride_C(batch_stride_C), - batch_stride_D(batch_stride_D) + MatrixCoord problem_size, + int32_t batch_count, + typename EpilogueOutputOp::Params output_op, + TensorRefA ref_A, + void const *ptr_B, + void const *ptr_C, + void *ptr_D, + int64_t batch_stride_A, + int64_t batch_stride_B, + int64_t batch_stride_C, + int64_t batch_stride_D + ): + problem_size(problem_size), + batch_count(batch_count), + output_op(output_op), + ref_A(ref_A), + ptr_B(static_cast(ptr_B)), + ptr_C(static_cast(ptr_C)), + ptr_D(static_cast(ptr_D)), + batch_stride_A(batch_stride_A), + batch_stride_B(batch_stride_B), + batch_stride_C(batch_stride_C), + batch_stride_D(batch_stride_D) { } Arguments( - MatrixCoord problem_size, - typename EpilogueOutputOp::Params output_op, - TensorRefA ref_A, - void const *ptr_B, - void const *ptr_C, - void *ptr_D - ): - Arguments( - problem_size, - 1, - output_op, - ref_A, - ptr_B, - ptr_C, - ptr_D, - 1, - 1, - 1, - 1) + MatrixCoord problem_size, + typename EpilogueOutputOp::Params output_op, + TensorRefA ref_A, + void const *ptr_B, + void const *ptr_C, + void *ptr_D + ): + Arguments( + problem_size, + 1, + output_op, + ref_A, + ptr_B, + ptr_C, + ptr_D, + 1, + 1, + 1, + 1) { } Status update(Arguments const &args) { diff --git a/include/cutlass/gemm/kernel/gemv_batched_strided.h b/include/cutlass/gemm/kernel/gemv_batched_strided.h index 613a279fdf..11490daf0c 100755 --- a/include/cutlass/gemm/kernel/gemv_batched_strided.h +++ b/include/cutlass/gemm/kernel/gemv_batched_strided.h @@ -149,7 +149,7 @@ CUTLASS_DEVICE void GemvBatchedStridedDevice( mma(problem_size.mnk(), accumulators, iterator_A, iterator_B, accumulators); // - // Epilogue (TODO: Epiloge as template argument) + // Epilogue // typename GemvKernel::FragmentCD fragment_CD; diff --git a/include/cutlass/gemm/kernel/grouped_problem_visitor.h b/include/cutlass/gemm/kernel/grouped_problem_visitor.h index 59a3657fbe..d013af0243 100644 --- a/include/cutlass/gemm/kernel/grouped_problem_visitor.h +++ b/include/cutlass/gemm/kernel/grouped_problem_visitor.h @@ -444,7 +444,6 @@ struct GroupedProblemVisitor +static inline bool +is_continous_k_aligned(GemmCoord problem_size, size_t alignmentA, size_t alignmentB) { + return (std::is_same::value && (problem_size.k() % alignmentA) == 0) || + (std::is_same::value && (problem_size.k() % alignmentB) == 0); +} + +} // namespace util + +///////////////////////////////////////////////////////////////////////////////////////////////// /// Argument structure struct UniversalArgumentsBase @@ -95,7 +107,9 @@ template < typename ThreadblockShape, typename ElementA, typename ElementB, - typename ElementC> + typename ElementC, + typename LayoutA, + typename LayoutB> struct UniversalParamsBase { // @@ -150,7 +164,18 @@ struct UniversalParamsBase if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { - int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + static const uint32_t CACHELINE_BYTES = 128; + static const size_t element_bytes_a = sizeof(ElementA); + static const size_t element_bytes_b = sizeof(ElementB); + static const size_t cacheline_elements_a = CACHELINE_BYTES / element_bytes_a; + static const size_t cacheline_elements_b = CACHELINE_BYTES / element_bytes_b; + + const bool cacheline_alignment_needed = + util::is_continous_k_aligned(problem_size, cacheline_elements_a, cacheline_elements_b); + + int const kAlignK = const_max( + const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), + cacheline_alignment_needed ? const_max(cacheline_elements_a, cacheline_elements_b) : 1); gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); if (gemm_k_size) { diff --git a/include/cutlass/gemm/kernel/rank_2k_grouped.h b/include/cutlass/gemm/kernel/rank_2k_grouped.h index 1c840e7aff..55955d4331 100644 --- a/include/cutlass/gemm/kernel/rank_2k_grouped.h +++ b/include/cutlass/gemm/kernel/rank_2k_grouped.h @@ -525,7 +525,7 @@ struct Rank2KGrouped { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/rank_2k_universal.h b/include/cutlass/gemm/kernel/rank_2k_universal.h index 6d1f4ac2ff..2775710d61 100644 --- a/include/cutlass/gemm/kernel/rank_2k_universal.h +++ b/include/cutlass/gemm/kernel/rank_2k_universal.h @@ -450,7 +450,7 @@ struct Rank2KUniversal { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/rank_k_universal.h b/include/cutlass/gemm/kernel/rank_k_universal.h index b7d1ad1958..188a4e70cf 100644 --- a/include/cutlass/gemm/kernel/rank_k_universal.h +++ b/include/cutlass/gemm/kernel/rank_k_universal.h @@ -403,7 +403,7 @@ struct RankKUniversal { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/sm70_gemm.hpp b/include/cutlass/gemm/kernel/sm70_gemm.hpp index d27993725d..e1fc4ec92f 100644 --- a/include/cutlass/gemm/kernel/sm70_gemm.hpp +++ b/include/cutlass/gemm/kernel/sm70_gemm.hpp @@ -45,13 +45,13 @@ template < class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, - class GridSwizzle_ + class TileScheduler_ > class GemmUniversal< ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, - GridSwizzle_, + TileScheduler_, cute::enable_if_t>> { public: @@ -59,7 +59,7 @@ class GemmUniversal< // Type Aliases // using ProblemShape = ProblemShape_; - using GridSwizzle = GridSwizzle_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); @@ -77,6 +77,14 @@ class GemmUniversal< using MainloopArguments = typename CollectiveMainloop::Arguments; using MainloopParams = typename CollectiveMainloop::Params; + static_assert(cute::is_void_v or cute::is_same_v, + "SM70 kernel does not support specializing the tile scheduler."); + using TileScheduleTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, + cute::Shape, cute::Int<1>, cute::Int<1>>>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + // Epilogue derived types using CollectiveEpilogue = CollectiveEpilogue_; using ElementC = typename CollectiveEpilogue::ElementC; @@ -88,9 +96,10 @@ class GemmUniversal< static_assert(cute::is_same_v, "Mainloop and epilogue do not agree on accumulator value type."); - static constexpr int SharedStorageSize = cute::max( + // MSVC requires the cast to fix a warning-as-error. + static constexpr int SharedStorageSize = static_cast(cute::max( sizeof(typename CollectiveMainloop::SharedStorage), - sizeof(typename CollectiveEpilogue::SharedStorage)); + sizeof(typename CollectiveEpilogue::SharedStorage))); static constexpr uint32_t MaxThreadsPerBlock = cute::size(TiledMma{}); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; @@ -102,6 +111,7 @@ class GemmUniversal< MainloopArguments mainloop{}; EpilogueArguments epilogue{}; KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; }; // Kernel entry point API @@ -140,6 +150,12 @@ class GemmUniversal< return 0; } + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return Status::kSuccess; + } + static dim3 get_grid_shape(Params const& params) { int batch_count = 1; @@ -169,7 +185,7 @@ class GemmUniversal< CUTE_STATIC_ASSERT(is_static::value); // Separate out problem shape for convenience - // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); auto M = get<0>(problem_shape_MNKL); auto N = get<1>(problem_shape_MNKL); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp index 00fc230f5d..7ab238f6de 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma.hpp @@ -39,6 +39,7 @@ #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/trace.h" #include "cute/tensor.hpp" @@ -65,13 +66,13 @@ template < class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, - class GridSwizzle_ + class TileScheduler_ > class GemmUniversal< ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, - GridSwizzle_, + TileScheduler_, cute::enable_if_t>> { public: @@ -79,7 +80,6 @@ class GemmUniversal< // Type Aliases // using ProblemShape = ProblemShape_; - using GridSwizzle = GridSwizzle_; static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); @@ -110,9 +110,16 @@ class GemmUniversal< static_assert(cute::is_same_v, "Mainloop and epilogue do not agree on accumulator value type."); - static constexpr int SharedStorageSize = cute::max( + static_assert(cute::is_void_v or cute::is_same_v, + "TMA kernel does not support specializing the tile scheduler."); + using TileScheduleTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + + static constexpr int SharedStorageSize = static_cast(cute::max( sizeof(typename CollectiveMainloop::SharedStorage), - sizeof(typename CollectiveEpilogue::SharedStorage)); + sizeof(typename CollectiveEpilogue::SharedStorage))); static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; @@ -124,6 +131,7 @@ class GemmUniversal< MainloopArguments mainloop{}; EpilogueArguments epilogue{}; KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; }; // Kernel entry point API @@ -163,36 +171,11 @@ class GemmUniversal< bool implementable = (args.mode == GemmUniversalMode::kGemm) or (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Size don't meet the requirements.\n"); - return implementable; - } - constexpr int tma_alignment_bits = 128; - constexpr int min_tma_aligned_elements = tma_alignment_bits / cutlass::sizeof_bits::value; - auto M = get<0>(args.problem_shape); - auto N = get<1>(args.problem_shape); - auto K = get<2>(args.problem_shape); - // Contiguous dimension for the TMA tensor should be 128b aligned - implementable = std::is_same_v, layout::RowMajor> ? - K % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0; - implementable = implementable && (std::is_same_v, layout::RowMajor> ? - N % min_tma_aligned_elements == 0 : K % min_tma_aligned_elements == 0); - implementable = implementable && (!cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA::value || - (cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA::value && - std::is_same_v, layout::RowMajor> ? - N % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0)); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); return implementable; } - - constexpr bool is_beta_supported = - CollectiveEpilogue::ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default; - implementable = is_beta_supported || (args.epilogue.thread.beta == 0 && args.epilogue.thread.beta_ptr == nullptr); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Scaling params don't meet ThreadEpilogueOp requirements.\n"); - return implementable; - } - + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); return implementable; } @@ -201,13 +184,18 @@ class GemmUniversal< return 0; } + static cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return Status::kSuccess; + } + // Computes the kernel launch grid shape based on runtime parameters static dim3 get_grid_shape(Params const& params) { auto cluster_shape = ClusterShape{}; auto tile_shape = TileShape{}; auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - return detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl( + return TileScheduler::get_tiled_cta_shape_mnl( problem_shape_MNKL, tile_shape, cluster_shape); } @@ -237,8 +225,9 @@ class GemmUniversal< static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); int thread_idx = int(threadIdx.x); - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); // Issue Tma Descriptor Prefetch from a single thread if ((warp_idx == 0) && lane_predicate) { @@ -246,7 +235,7 @@ class GemmUniversal< } // Separate out problem shape for convenience - // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); auto M = get<0>(problem_shape_MNKL); auto N = get<1>(problem_shape_MNKL); @@ -291,6 +280,7 @@ class GemmUniversal< accumulators, k_tile_iter, k_tile_count, thread_idx, + block_rank_in_cluster, smem_buf, params.mainloop ); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp index d708f82120..01d442746d 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp @@ -41,6 +41,8 @@ #include "cutlass/gemm/dispatch_policy.hpp" #include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" #include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + #include "cute/tensor.hpp" /////////////////////////////////////////////////////////////////////////////// @@ -53,13 +55,13 @@ template < class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, - class GridSwizzle_ + class TileScheduler_ > class GemmUniversal< ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, - GridSwizzle_, + TileScheduler_, cute::enable_if_t>> { public: @@ -67,7 +69,6 @@ class GemmUniversal< // Type Aliases // using ProblemShape = ProblemShape_; - using GridSwizzle = GridSwizzle_; static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); @@ -95,11 +96,17 @@ class GemmUniversal< using StrideD = typename CollectiveEpilogue::StrideD; using EpilogueArguments = typename CollectiveEpilogue::Arguments; using EpilogueParams = typename CollectiveEpilogue::Params; - static_assert(cute::is_same_v, - "Mainloop and epilogue do not agree on accumulator value type."); + + static_assert(cute::is_void_v or cute::is_same_v, + "TMA warp-specialized kernel does not support specializing the tile scheduler."); + using TileScheduleTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; // Kernel level shared memory storage struct SharedStorage { + // Mainloop and epilogue don't use smem concurrently since kernel is non-persistent, so we can use a union union TensorStorage { using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; @@ -131,6 +138,7 @@ class GemmUniversal< MainloopArguments mainloop{}; EpilogueArguments epilogue{}; KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; }; // Kernel entry point API @@ -170,36 +178,11 @@ class GemmUniversal< bool implementable = (args.mode == GemmUniversalMode::kGemm) or (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Size don't meet the requirements.\n"); - return implementable; - } - constexpr int tma_alignment_bits = 128; - constexpr int min_tma_aligned_elements = tma_alignment_bits / cutlass::sizeof_bits::value; - auto M = get<0>(args.problem_shape); - auto N = get<1>(args.problem_shape); - auto K = get<2>(args.problem_shape); - // Contiguous dimension for the TMA tensor should be 128b aligned - implementable = std::is_same_v, layout::RowMajor> ? - K % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0; - implementable = implementable && (std::is_same_v, layout::RowMajor> ? - N % min_tma_aligned_elements == 0 : K % min_tma_aligned_elements == 0); - implementable = implementable && (!cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA::value || - (cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA::value && - std::is_same_v, layout::RowMajor> ? - N % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0)); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); - return implementable; - } - - constexpr bool is_beta_supported = not cute::is_void_v && - CollectiveEpilogue::ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default; - implementable = is_beta_supported || (args.epilogue.thread.beta == 0 && args.epilogue.thread.beta_ptr == nullptr); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Scaling params don't meet ThreadEpilogueOp requirements.\n"); + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); return implementable; } - + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); return implementable; } @@ -209,13 +192,19 @@ class GemmUniversal< return 0; } + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return Status::kSuccess; + } + // Computes the kernel launch grid shape based on runtime parameters static dim3 get_grid_shape(Params const& params) { auto cluster_shape = ClusterShape{}; auto tile_shape = TileShape{}; auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - return detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl( + return TileScheduler::get_tiled_cta_shape_mnl( problem_shape_MNKL, tile_shape, cluster_shape); } @@ -242,15 +231,26 @@ class GemmUniversal< Producer = 0, Consumer = 1, }; + enum class ProducerWarpRole { + MainloopEpilogue = 0, + Warp1 = 1, + Warp2 = 2, + Warp3 = 3 + }; // Kernel level shared memory storage SharedStorage& shared_storage = *reinterpret_cast(smem_buf); int thread_idx = int(threadIdx.x); - int warp_idx = canonical_warp_idx(); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; - auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + // Issue Tma Descriptor Prefetch from a single thread if ((warp_idx == 0) && lane_predicate) { @@ -261,7 +261,7 @@ class GemmUniversal< // Mainloop Load pipeline using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer) { + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) { mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; } if (warp_group_role == WarpGroupRole::Consumer) { @@ -275,14 +275,14 @@ class GemmUniversal< // Epilogue Load pipeline using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer) { + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::MainloopEpilogue) { epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; } if (warp_group_role == WarpGroupRole::Consumer) { epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; } epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); - epi_load_pipeline_params.producer_arv_count = 1; // 1 thread issues TMA load + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); @@ -324,7 +324,7 @@ class GemmUniversal< static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); // Separate out problem shape for convenience - // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); auto M = get<0>(problem_shape_MNKL); auto N = get<1>(problem_shape_MNKL); @@ -358,44 +358,48 @@ class GemmUniversal< auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); auto k_tile_count = size<2>(gA); auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); - auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); + [[maybe_unused]] auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); // Wait for all thread blocks in the Cluster cluster_wait_fn(); // In a warp specialized kernel, collectives expose data movement and compute operations separately CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue{params.epilogue}; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); if (warp_group_role == WarpGroupRole::Producer) { - collective_mainloop.load( - mainloop_pipeline, - mainloop_pipe_producer_state, - gA, params.mainloop.tma_load_a, - gB, params.mainloop.tma_load_b, - k_tile_iter, k_tile_count, - thread_idx, - shared_storage.tensors.mainloop - ); - // Update starting mainloop pipeline state for the pipeline drain - mainloop_pipe_producer_state.advance(k_tile_count); - // Make sure mainloop consumer has been waited upon before issuing epilogue load - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - - if (collective_epilogue.is_source_needed()) { - collective_epilogue.load( - epi_load_pipeline, - epi_load_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - tiled_mma, - warp_group_thread_idx, - shared_storage.tensors.epilogue + if (producer_warp_role == ProducerWarpRole::MainloopEpilogue) { + collective_mainloop.load( + mainloop_pipeline, + mainloop_pipe_producer_state, + gA, params.mainloop.tma_load_a, + gB, params.mainloop.tma_load_b, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop ); - // Update starting load pipeline state for the pipeline drain - epi_load_pipe_producer_state.advance(c_tile_count); - collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + // Update starting mainloop pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + if (collective_epilogue.is_producer_load_needed()) { + // Ensure warp is converged before issuing epilogue loads + __syncwarp(); + epi_load_pipe_producer_state = + collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + lane_idx, + shared_storage.tensors.epilogue + ); + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } } } else if (warp_group_role == WarpGroupRole::Consumer) { @@ -419,6 +423,7 @@ class GemmUniversal< ); // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = collective_epilogue.store( epi_load_pipeline, epi_load_pipe_consumer_state, @@ -432,6 +437,13 @@ class GemmUniversal< warp_group_thread_idx, shared_storage.tensors.epilogue ); + + collective_epilogue.store_tail( + epi_load_pipeline, + epi_load_pipe_consumer_state_next, + epi_store_pipeline, + epi_store_pipe_producer_state_next + ); } } }; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index 74c06cf955..dfa18ad2c6 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -39,9 +39,10 @@ #include "cutlass/epilogue/collective/detail.hpp" #include "cutlass/gemm/gemm.h" #include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" #include "cutlass/pipeline/pipeline.hpp" #include "cute/tensor.hpp" +#include "cutlass/trace.h" /////////////////////////////////////////////////////////////////////////////// @@ -53,13 +54,13 @@ template < class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, - class GridSwizzle_ + class TileScheduler_ > class GemmUniversal< ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, - GridSwizzle_, + TileScheduler_, cute::enable_if_t>> { public: @@ -67,7 +68,6 @@ class GemmUniversal< // Type Aliases // using ProblemShape = ProblemShape_; - using GridSwizzle = GridSwizzle_; static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); @@ -94,14 +94,17 @@ class GemmUniversal< using StrideD = typename CollectiveEpilogue::StrideD; using EpilogueArguments = typename CollectiveEpilogue::Arguments; using EpilogueParams = typename CollectiveEpilogue::Params; - static_assert(cute::is_same_v, - "Mainloop and epilogue do not agree on accumulator value type."); - using PersistentTileSchedulerParams = typename detail::PersistentTileSchedulerSm90::Params; static_assert(ArchTag::kMinComputeCapability >= 90); + using TileScheduleTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = size(TiledMma{}) / NumThreadsPerWarpGroup; static constexpr uint32_t MaxThreadsPerBlock = size(TiledMma{}) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); static constexpr uint32_t MinBlocksPerMultiprocessor = 1; @@ -109,6 +112,9 @@ class GemmUniversal< static constexpr uint32_t LoadRegisterRequirement = 40; static constexpr uint32_t MmaRegisterRequirement = 232; + // 1 stage ordered sequence between mainloop and epilogue producer load threads + using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + // Kernel level shared memory storage struct SharedStorage { struct TensorStorage : cute::aligned_struct<128> { @@ -125,6 +131,7 @@ class GemmUniversal< alignas(16) MainloopPipelineStorage mainloop; alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; } pipelines; }; @@ -137,6 +144,7 @@ class GemmUniversal< MainloopArguments mainloop{}; EpilogueArguments epilogue{}; KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; }; // Kernel entry point API @@ -146,7 +154,8 @@ class GemmUniversal< MainloopParams mainloop; EpilogueParams epilogue; KernelHardwareInfo hw_info; - PersistentTileSchedulerParams scheduler; + TileSchedulerParams scheduler; + void* workspace; }; // @@ -159,14 +168,13 @@ class GemmUniversal< to_underlying_arguments(Arguments const& args, void* workspace) { CUTLASS_TRACE_HOST("to_underlying_arguments():"); - (void) workspace; auto problem_shape = args.problem_shape; if constexpr (detail::IF_SWAP_AB::value) { // swap M/N get<0>(problem_shape) = get<1>(args.problem_shape); get<1>(problem_shape) = get<0>(args.problem_shape); } - auto problem_shape_MNKL = append<4>(problem_shape, Int<1>{}); + auto problem_shape_MNKL = append<4>(problem_shape, 1); // Get SM count if needed, otherwise use user supplied SM count int sm_count = args.hw_info.sm_count; @@ -177,13 +185,19 @@ class GemmUniversal< } CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, workspace); + return { args.mode, problem_shape, CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), - {args.hw_info.device_id, sm_count}, - detail::PersistentTileSchedulerSm90::to_underlying_arguments(problem_shape_MNKL, TileShape{}, ClusterShape{}) + hw_info, + scheduler, + workspace }; } @@ -193,50 +207,38 @@ class GemmUniversal< bool implementable = (args.mode == GemmUniversalMode::kGemm) or (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Size don't meet the requirements.\n"); - return implementable; - } - constexpr int tma_alignment_bits = 128; - constexpr int min_tma_aligned_elements = tma_alignment_bits / cutlass::sizeof_bits::value; - auto M = get<0>(args.problem_shape); - auto N = get<1>(args.problem_shape); - auto K = get<2>(args.problem_shape); - // Contiguous dimension for the TMA tensor should be 128b aligned - implementable = std::is_same_v, layout::RowMajor> ? - K % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0; - implementable = implementable && (std::is_same_v, layout::RowMajor> ? - N % min_tma_aligned_elements == 0 : K % min_tma_aligned_elements == 0); - implementable = implementable && (!cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA::value || - (cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA::value && - std::is_same_v, layout::RowMajor> ? - N % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0)); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); - return implementable; - } - - constexpr bool is_beta_supported = - CollectiveEpilogue::ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default; - implementable = is_beta_supported || (args.epilogue.thread.beta == 0 && args.epilogue.thread.beta_ptr == nullptr); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Scaling params don't meet ThreadEpilogueOp requirements.\n"); + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); return implementable; } - + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); return implementable; } - static - int + static int get_workspace_size(Arguments const& args) { - return 0; + TileScheduler t; + return t.template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + } + + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + TileScheduler t; + return t.template initialize_workspace( + args.scheduler, workspace, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups); } // Computes the kernel launch grid shape based on runtime parameters static dim3 get_grid_shape(Params const& params) { // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently - return detail::PersistentTileSchedulerSm90::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info); + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); } static dim3 @@ -274,16 +276,26 @@ class GemmUniversal< Consumer0 = 1, Consumer1 = 2 }; + enum class ProducerWarpRole { + Mainloop = 0, + Warp1 = 1, + Epilogue = 2, + Warp3 = 3 + }; // Kernel level shared memory storage SharedStorage& shared_storage = *reinterpret_cast(smem_buf); int thread_idx = int(threadIdx.x); - int warp_idx = canonical_warp_idx(); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; int mma_thread_idx = thread_idx % size(TiledMma{}); auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); // Issue Tma Descriptor Prefetch from a single thread if ((warp_idx == 0) && lane_predicate) { @@ -294,7 +306,7 @@ class GemmUniversal< // Mainloop Load pipeline using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer) { + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) { mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; } if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { @@ -308,14 +320,14 @@ class GemmUniversal< // Epilogue Load pipeline using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer) { + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) { epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; } if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; } epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); - epi_load_pipeline_params.producer_arv_count = 1; // 1 thread issues TMA load + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; epi_load_pipeline_params.consumer_arv_count = size(TiledMma{}); epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); @@ -326,6 +338,11 @@ class GemmUniversal< epi_store_pipeline_params.always_wait = true; EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + typename LoadWarpOrderBarrier::Params params_load_order_barrier; + params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; + params_load_order_barrier.group_size = NumThreadsPerWarp; + LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); + // Initialize starting pipeline states for the collectives // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; @@ -351,7 +368,7 @@ class GemmUniversal< } (); // Separate out problem shape for convenience - // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); auto M = get<0>(problem_shape_MNKL); auto N = get<1>(problem_shape_MNKL); @@ -376,12 +393,12 @@ class GemmUniversal< auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); - detail::PersistentTileSchedulerSm90 scheduler; - auto work_tile_info = scheduler.get_current_work(params.scheduler); + TileScheduler scheduler{params.scheduler}; + auto work_tile_info = scheduler.get_current_work(); // In a warp specialized kernel, collectives expose data movement and compute operations separately CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue{params.epilogue}; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); // Wait for all thread blocks in the Cluster cluster_wait_fn(); @@ -389,76 +406,110 @@ class GemmUniversal< if (warp_group_role == WarpGroupRole::Producer) { cutlass::arch::warpgroup_reg_dealloc(); - while (work_tile_info.is_valid_tile) { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - // Slice with our work tile coordinates to construct mainloop tensor views - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) - - auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); - - collective_mainloop.load( - mainloop_pipeline, - mainloop_pipe_producer_state, - gA, params.mainloop.tma_load_a, - gB, params.mainloop.tma_load_b, - k_tile_iter, k_tile_count, - thread_idx, - shared_storage.tensors.mainloop - ); - // Update starting pipeline state for the next tile - mainloop_pipe_producer_state.advance(k_tile_count); - - if (collective_epilogue.is_source_needed()) { - collective_epilogue.load( - epi_load_pipeline, - epi_load_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - tiled_mma, - warp_group_thread_idx, - shared_storage.tensors.epilogue + // Mainloop Producer Warp + if (producer_warp_role == ProducerWarpRole::Mainloop) { + bool do_load_order_arrive = true; + while (work_tile_info.is_valid_tile) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Slice with our work tile coordinates to construct mainloop tensor views + Tensor gA_presplit = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB_presplit = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Split operands A and B along the K dimension according to work_tile_info + Tensor gA = TileScheduler::split_MK(gA_presplit, work_tile_info); // (BLK_N,BLK_K,k_split_iters) + Tensor gB = TileScheduler::split_NK(gB_presplit, work_tile_info); // (BLK_N,BLK_K,k_split_iters) + + auto work_k_tile_count = size<2>(gA); + auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA_presplit)); + + collective_mainloop.load( + mainloop_pipeline, + mainloop_pipe_producer_state, + gA, params.mainloop.tma_load_a, + gB, params.mainloop.tma_load_b, + k_tile_iter, work_k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop ); // Update starting pipeline state for the next tile - epi_load_pipe_producer_state.advance(c_tile_count); - } - - // Get next work tile - scheduler.advance_to_next_work(); - work_tile_info = scheduler.get_current_work(params.scheduler); - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - if (collective_epilogue.is_source_needed()) { + mainloop_pipe_producer_state.advance(work_k_tile_count); + + // Signal for the epilogue load warp to begin + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } // Mainloop Producer Warp End + + // Epilogue Producer Warp + else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) { + load_order_barrier.wait(); + while (work_tile_info.is_valid_tile) { + if (TileScheduler::compute_epilogue(work_tile_info)) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + epi_load_pipe_producer_state = + collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + lane_idx, + shared_storage.tensors.epilogue + ); + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } + } // Epilogue Producer Warp End } // Producer Warp Group End else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { cutlass::arch::warpgroup_reg_alloc(); + // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it + bool do_store_tail = false; while (work_tile_info.is_valid_tile) { // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); // Allocate the the accumulators for the (M,N) blk_shape - Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + // + // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. + auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) collective_mainloop.mma( mainloop_pipeline, mainloop_pipe_consumer_state, accumulators, - k_tile_count, + work_k_tile_count, mma_thread_idx, shared_storage.tensors.mainloop, params.mainloop @@ -468,35 +519,73 @@ class GemmUniversal< collective_mainloop.mma_tail( mainloop_pipeline, mainloop_pipe_consumer_state, - k_tile_count + work_k_tile_count ); + // Update starting mainloop pipeline state for the next tile - mainloop_pipe_consumer_state.advance(k_tile_count); + mainloop_pipe_consumer_state.advance(work_k_tile_count); + + // Index of warp group within consumer warp groups + int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; + + // Perform reduction across splits, if needed + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx); - // Epilogue and write to gD - collective_epilogue.store( + if (TileScheduler::compute_epilogue(work_tile_info)) { + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators, + tiled_mma, + mma_thread_idx, + shared_storage.tensors.epilogue + ); + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; + epi_store_pipe_producer_state = epi_store_pipe_producer_state_next; + do_store_tail = true; + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + if (do_store_tail) { + collective_epilogue.store_tail( epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, - epi_store_pipe_producer_state, - problem_shape_MNKL, - blk_shape, - blk_coord, - accumulators, - tiled_mma, - mma_thread_idx, - shared_storage.tensors.epilogue + epi_store_pipe_producer_state ); - // Update starting load/store pipeline states for the next tile - epi_load_pipe_consumer_state.advance(c_tile_count); - epi_store_pipe_producer_state.advance(d_tile_count); - - // Get next work tile - scheduler.advance_to_next_work(); - work_tile_info = scheduler.get_current_work(params.scheduler); - } // Scheduler work fetch loop + } } // Consumer Warp Groups End } + +private: + // Kernel helper function to get next work unit + CUTLASS_DEVICE + typename TileScheduler::WorkTileInfo + fetch_next_work( + typename TileScheduler::WorkTileInfo& work_tile_info, + TileScheduler& scheduler) const { + // Check whether we should continue on with the current work unit. If this is the case, + // the work unit will have been updated in continue_current_work to reflect the new + // tile to be computed. + if (scheduler.continue_current_work(work_tile_info)) { + return work_tile_info; + } + + // Get next work tile + scheduler.advance_to_next_work(); + return scheduler.get_current_work(); + } }; /////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index e7cc85f0a7..77c88c32a0 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -55,13 +55,13 @@ template < class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, - class GridSwizzle_ + class TileScheduler_ > class GemmUniversal< ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, - GridSwizzle_, + TileScheduler_, cute::enable_if_t>> { public: @@ -69,7 +69,6 @@ class GemmUniversal< // Type Aliases // using ProblemShape = ProblemShape_; - using GridSwizzle = GridSwizzle_; static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); @@ -87,7 +86,6 @@ class GemmUniversal< using ClusterShape = typename DispatchPolicy::ClusterShape; using MainloopArguments = typename CollectiveMainloop::Arguments; using MainloopParams = typename CollectiveMainloop::Params; - using PersistentTileSchedulerParams = typename detail::PersistentTileSchedulerSm90::Params; static_assert(ArchTag::kMinComputeCapability >= 90); // Epilogue derived types @@ -98,8 +96,14 @@ class GemmUniversal< using StrideD = typename CollectiveEpilogue::StrideD; using EpilogueArguments = typename CollectiveEpilogue::Arguments; using EpilogueParams = typename CollectiveEpilogue::Params; - static_assert(cute::is_same_v, - "Mainloop and epilogue do not agree on accumulator value type."); + + static_assert(cute::is_void_v or cute::is_same_v, + "Ping-pong kernel only supports the default scheduler."); + using TileScheduleTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; static constexpr uint32_t NumLoadWarpGroups = 1; static constexpr uint32_t NumMmaWarpGroups = 2; @@ -110,6 +114,9 @@ class GemmUniversal< static constexpr uint32_t LoadRegisterRequirement = 40; static constexpr uint32_t MmaRegisterRequirement = 232; + // 1 stage ordered sequence between mainloop and epilogue producer load threads + using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1,2>; + // Order Sequence barrier with two stages: one for Mainloop and one for Epilogue static constexpr uint32_t StagesPerMathWarpGroup = 2; using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier< @@ -133,6 +140,7 @@ class GemmUniversal< alignas(16) MainloopPipelineStorage mainloop; alignas(16) EpiLoadPipelineStorage epi_load; alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; + alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; } pipelines; }; @@ -145,6 +153,7 @@ class GemmUniversal< MainloopArguments mainloop{}; EpilogueArguments epilogue{}; KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; }; // Kernel entry point API @@ -154,7 +163,7 @@ class GemmUniversal< MainloopParams mainloop; EpilogueParams epilogue; KernelHardwareInfo hw_info; - PersistentTileSchedulerParams scheduler; + TileSchedulerParams scheduler; }; // @@ -174,7 +183,7 @@ class GemmUniversal< get<0>(problem_shape) = get<1>(args.problem_shape); get<1>(problem_shape) = get<0>(args.problem_shape); } - auto problem_shape_MNKL = append<4>(problem_shape, Int<1>{}); + auto problem_shape_MNKL = append<4>(problem_shape, 1); // Get SM count if needed, otherwise use user supplied SM count int sm_count = args.hw_info.sm_count; @@ -185,13 +194,15 @@ class GemmUniversal< } CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + return { args.mode, problem_shape, CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), - {args.hw_info.device_id, sm_count}, - detail::PersistentTileSchedulerSm90::to_underlying_arguments(problem_shape_MNKL, TileShape{}, ClusterShape{}) + hw_info, + TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler) }; } @@ -201,36 +212,11 @@ class GemmUniversal< bool implementable = (args.mode == GemmUniversalMode::kGemm) or (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Size don't meet the requirements.\n"); - return implementable; - } - constexpr int tma_alignment_bits = 128; - constexpr int min_tma_aligned_elements = tma_alignment_bits / cutlass::sizeof_bits::value; - auto M = get<0>(args.problem_shape); - auto N = get<1>(args.problem_shape); - auto K = get<2>(args.problem_shape); - // Contiguous dimension for the TMA tensor should be 128b aligned - implementable = std::is_same_v, layout::RowMajor> ? - K % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0; - implementable = implementable && (std::is_same_v, layout::RowMajor> ? - N % min_tma_aligned_elements == 0 : K % min_tma_aligned_elements == 0); - implementable = implementable && (!cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA::value || - (cutlass::epilogue::collective::detail::IF_EPILOGUE_USES_TMA::value && - std::is_same_v, layout::RowMajor> ? - N % min_tma_aligned_elements == 0 : M % min_tma_aligned_elements == 0)); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); return implementable; } - - constexpr bool is_beta_supported = - CollectiveEpilogue::ThreadEpilogueOp::kScale == cutlass::epilogue::thread::ScaleType::Default; - implementable = is_beta_supported || (args.epilogue.thread.beta == 0 && args.epilogue.thread.beta_ptr == nullptr); - if (!implementable) { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Scaling params don't meet ThreadEpilogueOp requirements.\n"); - return implementable; - } - + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); return implementable; } @@ -240,11 +226,21 @@ class GemmUniversal< return 0; } + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return Status::kSuccess; + } + // Computes the kernel launch grid shape based on runtime parameters static dim3 get_grid_shape(Params const& params) { // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently - return detail::PersistentTileSchedulerSm90::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info); + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); } static dim3 @@ -277,15 +273,25 @@ class GemmUniversal< Consumer0 = 1, Consumer1 = 2 }; + enum class ProducerWarpRole { + Mainloop = 0, + Warp1 = 1, + Epilogue = 2, + Warp3 = 3 + }; // Kernel level shared memory storage SharedStorage& shared_storage = *reinterpret_cast(smem_buf); int thread_idx = int(threadIdx.x); - int warp_idx = canonical_warp_idx(); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); // Issue Tma Descriptor Prefetch from a single thread if ((warp_idx == 0) && lane_predicate) { @@ -296,7 +302,7 @@ class GemmUniversal< // Mainloop Load pipeline using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer) { + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) { mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; } if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { @@ -310,14 +316,14 @@ class GemmUniversal< // Epilogue Load pipeline using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer) { + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) { epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; } if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; } epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); - epi_load_pipeline_params.producer_arv_count = 1; // 1 thread issues TMA load + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); @@ -328,6 +334,11 @@ class GemmUniversal< epi_store_pipeline_params.always_wait = true; EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + typename LoadWarpOrderBarrier::Params params_load_order_barrier; + params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; + params_load_order_barrier.group_size = NumThreadsPerWarp; + LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); + typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; // DMA Load WG will not participate in these Ordered Barrier syncs params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast(WarpGroupRole::Consumer0); @@ -359,7 +370,7 @@ class GemmUniversal< } (); // Separate out problem shape for convenience - // Optionally append _1s until problem shape is rank-4 in case its is only rank-3 (MNK) + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); auto M = get<0>(problem_shape_MNKL); auto N = get<1>(problem_shape_MNKL); @@ -384,7 +395,7 @@ class GemmUniversal< auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); - detail::PersistentTileSchedulerSm90 scheduler; + TileScheduler scheduler{params.scheduler}; if (warp_group_role == WarpGroupRole::Consumer1) { // Advance 2nd Math WG to the next work tile for the startup @@ -394,11 +405,11 @@ class GemmUniversal< epi_load_pipe_consumer_state.advance(c_tile_count); epi_store_pipe_producer_state.advance(d_tile_count); } - auto work_tile_info = scheduler.get_current_work(params.scheduler); + auto work_tile_info = scheduler.get_current_work(); // In a warp specialized kernel, collectives expose data movement and compute operations separately CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue{params.epilogue}; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); // Wait for all thread blocks in the Cluster cluster_wait_fn(); @@ -406,32 +417,61 @@ class GemmUniversal< if (warp_group_role == WarpGroupRole::Producer) { cutlass::arch::warpgroup_reg_dealloc(); - while (work_tile_info.is_valid_tile) { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - // Slice with our work tile coordinates to construct mainloop tensor views - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) - - auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); - - collective_mainloop.load( - mainloop_pipeline, - mainloop_pipe_producer_state, - gA, params.mainloop.tma_load_a, - gB, params.mainloop.tma_load_b, - k_tile_iter, k_tile_count, - thread_idx, - shared_storage.tensors.mainloop - ); - // Update starting pipeline state for the next tile - mainloop_pipe_producer_state.advance(k_tile_count); - - if (collective_epilogue.is_source_needed()) { + // Mainloop Producer Warp + if (producer_warp_role == ProducerWarpRole::Mainloop) { + bool do_load_order_arrive = true; + while (work_tile_info.is_valid_tile) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Slice with our work tile coordinates to construct mainloop tensor views + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); + + collective_mainloop.load( + mainloop_pipeline, + mainloop_pipe_producer_state, + gA, params.mainloop.tma_load_a, + gB, params.mainloop.tma_load_b, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(k_tile_count); + + // Signal for the epilogue load warp to begin + if (do_load_order_arrive) { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + // Get next work tile + scheduler.advance_to_next_work(); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } // Mainloop Producer Warp End + + // Epilogue Producer Warp + else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) { + load_order_barrier.wait(); + while (work_tile_info.is_valid_tile) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + epi_load_pipe_producer_state = collective_epilogue.load( epi_load_pipeline, epi_load_pipe_producer_state, @@ -439,23 +479,18 @@ class GemmUniversal< blk_shape, blk_coord, tiled_mma, - warp_group_thread_idx, + lane_idx, shared_storage.tensors.epilogue ); - // Update starting pipeline state for the next tile - epi_load_pipe_producer_state.advance(c_tile_count); - } // Get next work tile scheduler.advance_to_next_work(); - work_tile_info = scheduler.get_current_work(params.scheduler); + work_tile_info = scheduler.get_current_work(); } // Scheduler work fetch loop - // Make sure all Consumer Warp Groups have been waited upon - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - if (collective_epilogue.is_source_needed()) { + // Make sure all Consumer Warp Groups have been waited upon collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } + } // Epilogue Producer Warp End } // Producer Warp Group End else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { @@ -500,6 +535,7 @@ class GemmUniversal< math_wg_order_barrier.wait(); // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = collective_epilogue.store( epi_load_pipeline, epi_load_pipe_consumer_state, @@ -513,19 +549,31 @@ class GemmUniversal< warp_group_thread_idx, shared_storage.tensors.epilogue ); - // Update starting load/store pipeline states for the next tile - epi_load_pipe_consumer_state.advance(c_tile_count * NumMmaWarpGroups); - epi_store_pipe_producer_state.advance(d_tile_count * NumMmaWarpGroups); - // Wait for all TMA stores to complete - epi_store_pipeline.producer_tail(epi_store_pipe_producer_state); + // TMA store pipeline wait is only visible to TMA-issuing warp, so for multiple-consumer kernels + // we need to wait for all TMA stores to complete before issuing consumer order barrier arrives + // to ensure next math consumer doesn't overwrite smem of in-flight TMA stores of current consumer. + auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_] = + collective_epilogue.store_tail( + epi_load_pipeline, + epi_load_pipe_consumer_state_next, + epi_store_pipeline, + epi_store_pipe_producer_state_next + ); + + // Update starting load/store pipeline states for the next tile + // state has already been incremented by 1 tile in collective calls, advance once again for ping pong + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_; + epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_; + epi_load_pipe_consumer_state.advance(c_tile_count); + epi_store_pipe_producer_state.advance(d_tile_count); // Cue for next Math WG's Epilogue to start math_wg_order_barrier.arrive(); // Get next work tile scheduler.advance_to_next_work(NumMmaWarpGroups); - work_tile_info = scheduler.get_current_work(params.scheduler); + work_tile_info = scheduler.get_current_work(); } // Scheduler work fetch loop } // Consumer Warp Groups End } diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp index 744ee7a553..3b78f15b55 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp @@ -33,6 +33,7 @@ #include "cutlass/fast_math.h" #include "cutlass/kernel_hardware_info.hpp" #include "cute/layout.hpp" +#include "cute/tensor.hpp" namespace cutlass::gemm::kernel::detail { @@ -45,73 +46,201 @@ class PersistentTileSchedulerSm90 { // private: - uint64_t current_work_linear_idx_{static_cast((int(blockIdx.x) * int(gridDim.y)) + int(blockIdx.y))}; - uint64_t grid_blocks_total_{static_cast(int(gridDim.x) * int(gridDim.y))}; + uint64_t current_work_linear_idx_; +public: struct WorkTileInfo { int32_t M_idx = 0; int32_t N_idx = 0; int32_t L_idx = 0; - uint32_t is_valid_tile = false; + bool is_valid_tile = false; }; // // Methods // -public: + enum class RasterOrder { + AlongM, + AlongN + }; + + struct Arguments { + int max_swizzle_size = 1; + }; struct Params { + + FastDivmodU64 divmod_cluster_shape_major_{}; + FastDivmodU64 divmod_cluster_shape_minor_{}; FastDivmodU64 divmod_batch_{}; - FastDivmodU64 divmod_grid_y_{}; - FastDivmodU64 divmod_blk_m_{}; + FastDivmodU64 divmod_cluster_blk_major_{}; uint64_t blocks_per_problem_ = 0; + int32_t log_swizzle_size_ = 0; + RasterOrder raster_order_ = RasterOrder::AlongN; }; + // Sink scheduler params as a member + Params scheduler_params; + + // + // Methods + // template static Params - to_underlying_arguments(ProblemShapeMNKL problem_shape_mnkl, TileShape tile_shape, ClusterShape cluster_shape) { + to_underlying_arguments( + ProblemShapeMNKL problem_shape_mnkl, + TileShape tile_shape, + ClusterShape cluster_shape, + [[maybe_unused]] KernelHardwareInfo const& hw_info, + Arguments const& arguments, + [[maybe_unused]] void* workspace=nullptr) { + // We only need the tile and cluster shape during scheduler setup, so let FTAD do the magic - static_assert(is_static::value); - static_assert(is_static::value); + static_assert(cute::is_static::value); + static_assert(cute::is_static::value); // Round up to nearest multiple of cluster dim along each mode - auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = get_tiled_blk_shape_mnl( + auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = get_tiled_cta_shape_mnl( problem_shape_mnkl, tile_shape, cluster_shape); - return { - FastDivmodU64(problem_blocks_m * problem_blocks_n), - FastDivmodU64(size<1>(cluster_shape)), - FastDivmodU64(problem_blocks_m), - problem_blocks_m * problem_blocks_n * problem_blocks_l - }; + // Round up to nearest multiple of swizzle_size along each mode + auto log_swizzle_size = get_log_swizzle_size(problem_blocks_m, problem_blocks_n, arguments.max_swizzle_size); + problem_blocks_m = round_up(problem_blocks_m, (1 << log_swizzle_size) * cute::size<0>(cluster_shape)); + problem_blocks_n = round_up(problem_blocks_n, (1 << log_swizzle_size) * cute::size<1>(cluster_shape)); + + + RasterOrder raster_order; + raster_order = get_rasterization_order(problem_shape_mnkl, tile_shape); + if (raster_order == RasterOrder::AlongN) { + return { + FastDivmodU64(cute::size<1>(cluster_shape)), + FastDivmodU64(cute::size<0>(cluster_shape)), + FastDivmodU64(problem_blocks_m * problem_blocks_n), + FastDivmodU64(problem_blocks_n / cute::size<1>(cluster_shape)), + problem_blocks_m * problem_blocks_n * problem_blocks_l, + log_swizzle_size, + raster_order + }; + } + else { + return { + FastDivmodU64(cute::size<0>(cluster_shape)), + FastDivmodU64(cute::size<1>(cluster_shape)), + FastDivmodU64(problem_blocks_m * problem_blocks_n), + FastDivmodU64(problem_blocks_m / cute::size<0>(cluster_shape)), + problem_blocks_m * problem_blocks_n * problem_blocks_l, + log_swizzle_size, + raster_order + }; + } + } + + CUTLASS_HOST_DEVICE + PersistentTileSchedulerSm90() { }; + + CUTLASS_DEVICE explicit PersistentTileSchedulerSm90(Params const& params_) : scheduler_params(params_) { + // MSVC requires protecting use of CUDA-specific nonstandard syntax, + // like blockIdx and gridDim, with __CUDA_ARCH__. +#if defined(__CUDA_ARCH__) + if (params_.raster_order_ == RasterOrder::AlongN) { + current_work_linear_idx_ = static_cast(int(blockIdx.x) + (int(blockIdx.y) * int(gridDim.x))); + } + else { + current_work_linear_idx_ = static_cast((int(blockIdx.x) * int(gridDim.y)) + int(blockIdx.y)); + } +#else + CUTLASS_ASSERT(false && "This line should never be reached"); +#endif } - PersistentTileSchedulerSm90() = default; + CUTLASS_DEVICE + WorkTileInfo + get_current_work() const { + return get_current_work_for_linear_idx(current_work_linear_idx_); + } CUTLASS_DEVICE WorkTileInfo - get_current_work(Params const& scheduler_params) const { + get_current_work_for_linear_idx(uint64_t linear_idx) const { // Map worker's linear index into the CTA tiled problem shape to the corresponding MNL indices uint64_t work_idx_l, remainder; - scheduler_params.divmod_batch_(work_idx_l, remainder, current_work_linear_idx_); + scheduler_params.divmod_batch_(work_idx_l, remainder, linear_idx); - uint64_t blk_per_grid_dim, dontcare; - scheduler_params.divmod_grid_y_(blk_per_grid_dim, dontcare, remainder); + uint64_t blk_per_grid_dim = scheduler_params.divmod_cluster_shape_minor_.divide(remainder); - uint64_t block_idx_m, block_idx_n; - scheduler_params.divmod_blk_m_(block_idx_n, block_idx_m, blk_per_grid_dim); - int32_t work_idx_m = static_cast(block_idx_m); - int32_t work_idx_n = static_cast((block_idx_n * gridDim.y) + blockIdx.y); + auto [work_idx_m, work_idx_n] = get_work_idx_m_and_n(blk_per_grid_dim, + scheduler_params.divmod_cluster_shape_major_, + scheduler_params.divmod_cluster_shape_minor_, + scheduler_params.divmod_cluster_blk_major_, + scheduler_params.log_swizzle_size_, + scheduler_params.raster_order_); - return {work_idx_m, work_idx_n, static_cast(work_idx_l), current_work_linear_idx_ < scheduler_params.blocks_per_problem_}; + return {work_idx_m, work_idx_n, static_cast(work_idx_l), linear_idx < scheduler_params.blocks_per_problem_}; } CUTLASS_DEVICE void advance_to_next_work(uint32_t advance_count = 1) { - current_work_linear_idx_ += grid_blocks_total_ * advance_count; + // MSVC requires protecting use of CUDA-specific nonstandard syntax, + // like blockIdx and gridDim, with __CUDA_ARCH__. +#if defined(__CUDA_ARCH__) + current_work_linear_idx_ += static_cast(int(gridDim.x) * int(gridDim.y) * int(gridDim.z)) * advance_count; +#else + CUTLASS_ASSERT(false && "This line should never be reached"); +#endif + } + + // get work_idx_m, work_idx_n from blk_per_grid_dim while applying swizzle + static CUTLASS_DEVICE + cute::tuple + get_work_idx_m_and_n( + uint64_t blk_per_grid_dim, + FastDivmodU64 const& divmod_cluster_shape_major, + FastDivmodU64 const& divmod_cluster_shape_minor, + FastDivmodU64 const& divmod_cluster_blk_major, + int32_t log_swizzle_size, + RasterOrder raster_order) { + + uint64_t cluster_id, cluster_major_offset = 0, cluster_minor_offset = 0; + divmod_cluster_shape_major(cluster_id, cluster_major_offset, blk_per_grid_dim); + // MSVC requires protecting use of CUDA-specific nonstandard syntax, + // like blockIdx and gridDim, with __CUDA_ARCH__. +#if defined(__CUDA_ARCH__) + if (raster_order == RasterOrder::AlongN) { + cluster_minor_offset = blockIdx.x; + } + else { + cluster_minor_offset = blockIdx.y; + } +#else + CUTLASS_ASSERT(false && "This line should never be reached"); +#endif + + uint64_t cluster_idx_minor, cluster_idx_major; + + uint64_t cluster_idx_minor_div_swizzle, extra, offset; + + offset = cluster_id & ((1 << log_swizzle_size) - 1); + extra = cluster_id >> log_swizzle_size; + + divmod_cluster_blk_major(cluster_idx_minor_div_swizzle, cluster_idx_major, extra); + + cluster_idx_minor = cluster_idx_minor_div_swizzle * (1 << log_swizzle_size) + offset; + + auto minor_work_idx = static_cast(cluster_idx_minor * divmod_cluster_shape_minor.divisor + + cluster_minor_offset); + auto major_work_idx = static_cast(cluster_idx_major * divmod_cluster_shape_major.divisor + + cluster_major_offset); + + if (raster_order == RasterOrder::AlongN) { + return {minor_work_idx, major_work_idx}; + } + else { + return {major_work_idx, minor_work_idx}; + } + } // Given the inputs, computes the total number of output blocks this problem will compute over @@ -119,38 +248,93 @@ class PersistentTileSchedulerSm90 { template CUTLASS_HOST_DEVICE static dim3 - get_tiled_blk_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, BlockShape blk_shape, ClusterShape cluster_shape) { + get_tiled_cta_shape_mnl(ProblemShapeMNKL problem_shape_mnkl, BlockShape cta_shape, ClusterShape cluster_shape) { // Across M and N is our Cluster tile, so we must round up the blocks to the nearest whole number of Cluster tiles - auto blk_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shape_mnkl), cute::shape<0>(blk_shape))); - auto blk_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shape_mnkl), cute::shape<1>(blk_shape))); + auto cta_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shape_mnkl), cute::shape<0>(cta_shape))); + auto cta_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shape_mnkl), cute::shape<1>(cta_shape))); // Round up to nearest multiple of cluster dim along each mode - int problem_blocks_m = round_up(blk_m, cute::size<0>(cluster_shape)); - int problem_blocks_n = round_up(blk_n, cute::size<1>(cluster_shape)); + int problem_blocks_m = round_up(cta_m, cute::size<0>(cluster_shape)); + int problem_blocks_n = round_up(cta_n, cute::size<1>(cluster_shape)); // Cluster tile does not span the batch mode, so no extra rounding up required for it int problem_blocks_l = int(cute::size<3>(problem_shape_mnkl)); return {uint32_t(problem_blocks_m), uint32_t(problem_blocks_n), uint32_t(problem_blocks_l)}; } + CUTLASS_HOST_DEVICE + static int32_t + get_log_swizzle_size(int problem_ctas_m, int problem_ctas_n, int max_swizzle_size) { + int min_cta_dim = min(problem_ctas_m, problem_ctas_n); + if (max_swizzle_size >= 8 && min_cta_dim >= 6) { + return 3; + } + else if (max_swizzle_size >= 4 && min_cta_dim >= 3) { + return 2; + } + else if (max_swizzle_size >= 2 && min_cta_dim >= 2) { + return 1; + } + else { + return 0; + } + } + // Given the inputs, computes the physical grid we should launch. template CUTLASS_HOST_DEVICE static dim3 - get_grid_shape(ProblemShapeMNKL problem_shape_mnk, BlockShape blk_shape, ClusterShape cluster_shape, KernelHardwareInfo hw_info) { + get_grid_shape( + ProblemShapeMNKL problem_shape_mnk, + BlockShape cta_shape, + ClusterShape cluster_shape, + KernelHardwareInfo hw_info, + Arguments arguments, + bool truncate_by_problem_size=true) { + int const sm_count = hw_info.sm_count; CUTLASS_TRACE_HOST("get_grid_shape(): Persistent schedule grid plan using SM count = " << sm_count); + // Compute the total number of output tiles our problem has - auto problem_shape_MNKL = append<4>(problem_shape_mnk, Int<1>{}); + auto problem_shape_MNKL = cute::append<4>(problem_shape_mnk, cute::Int<1>{}); auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = - get_tiled_blk_shape_mnl(problem_shape_MNKL, blk_shape, cluster_shape); + get_tiled_cta_shape_mnl(problem_shape_MNKL, cta_shape, cluster_shape); + + // Round up to nearest multiple of swizzle_size along each mode + auto swizzle_size = 1 << get_log_swizzle_size(problem_blocks_m, problem_blocks_n, arguments.max_swizzle_size); + problem_blocks_m = round_up(problem_blocks_m, swizzle_size * cute::size<0>(cluster_shape)); + problem_blocks_n = round_up(problem_blocks_n, swizzle_size * cute::size<1>(cluster_shape)); + int problem_blocks_total = problem_blocks_m * problem_blocks_n * problem_blocks_l; - dim3 launch_grid(1, cute::size<1>(cluster_shape), 1); + RasterOrder raster_order; + raster_order = get_rasterization_order(problem_shape_mnk, cta_shape); + dim3 launch_grid; + + if (raster_order == RasterOrder::AlongN) { + launch_grid = dim3(cute::size<0>(cluster_shape), 1, 1); + } + else { + launch_grid = dim3(1, cute::size<1>(cluster_shape), 1); + } + + auto possibly_truncate = [&](int x, int y) { + if (truncate_by_problem_size) { + return std::min(x, y); + } + else { + return x; + } + }; - // The else path is generic, however, we can avoid some divs if we know Cluster size is 1 + // The else path is generic, however, we can avoid some divs if we know cluster size is 1 if constexpr (size(cluster_shape) == 1) { - launch_grid.x = std::min(sm_count, problem_blocks_total); + if (raster_order == RasterOrder::AlongN) { + launch_grid.y = possibly_truncate(sm_count, problem_blocks_total); + } + else { + launch_grid.x = possibly_truncate(sm_count, problem_blocks_total); + } } else { /* @@ -161,22 +345,105 @@ class PersistentTileSchedulerSm90 { constexpr int max_sm_per_gpc = 18; // Provided SM count could possibly be less than the assumed maximum SMs per GPC int const min_num_gpc = sm_count < max_sm_per_gpc ? 1 : sm_count / max_sm_per_gpc; - int const max_blk_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % size(cluster_shape)); - int blk_per_device = min_num_gpc * max_blk_occupancy_per_gpc; + int const max_cta_occupancy_per_gpc = max_sm_per_gpc - (max_sm_per_gpc % size(cluster_shape)); + int cta_per_device = min_num_gpc * max_cta_occupancy_per_gpc; // The calculation below allows for larger grid size launch for different GPUs. int const num_gpc_residual = sm_count < max_sm_per_gpc ? 0 : sm_count % max_sm_per_gpc; - int const max_blk_occupancy_per_residual_gpc = num_gpc_residual - (num_gpc_residual % size(cluster_shape)); - blk_per_device += max_blk_occupancy_per_residual_gpc; + int const max_cta_occupancy_per_residual_gpc = num_gpc_residual - (num_gpc_residual % size(cluster_shape)); + cta_per_device += max_cta_occupancy_per_residual_gpc; - blk_per_device = sm_count < blk_per_device ? sm_count : blk_per_device; + cta_per_device = sm_count < cta_per_device ? sm_count : cta_per_device; - launch_grid.x = std::min( - blk_per_device / size<1>(cluster_shape), - problem_blocks_total / size<1>(cluster_shape)); + if (raster_order == RasterOrder::AlongN) { + launch_grid.y = possibly_truncate( + cta_per_device / cute::size<0>(cluster_shape), + problem_blocks_total / cute::size<0>(cluster_shape)); + } + else { + launch_grid.x = possibly_truncate( + cta_per_device / cute::size<1>(cluster_shape), + problem_blocks_total / cute::size<1>(cluster_shape)); + } } return launch_grid; } + + template + CUTLASS_HOST_DEVICE static RasterOrder get_rasterization_order(ProblemShapeMNKL problem_shape_mnkl, BlockShape cta_shape) { + auto tiles_m = cute::size(cute::ceil_div(cute::shape<0>(problem_shape_mnkl), cute::shape<0>(cta_shape))); + auto tiles_n = cute::size(cute::ceil_div(cute::shape<1>(problem_shape_mnkl), cute::shape<1>(cta_shape))); + + if (tiles_n > tiles_m) { + return RasterOrder::AlongM; + } + + return RasterOrder::AlongN; + } + + // Splits an input tensor with MxK according to the splitting configuration specified by work_tile_info. + // Since the basic tile scheduler does not split output tiles, this method is a no-op. + template + CUTLASS_DEVICE + static auto + split_MK(cute::Tensor const& tensor, WorkTileInfo const&) { + return tensor; + } + + // Splits an input tensor with NxK tiles according to the splitting configuration specified by work_tile_info. + // Since the basic tile scheduler does not split output tiles, this method is a no-op. + template + CUTLASS_DEVICE + static auto + split_NK(cute::Tensor const& tensor, WorkTileInfo const&) { + return tensor; + } + + // Returns whether the block assigned this work should compute the epilogue for the corresponding + // output tile. For the basic tile scheduler, this is always true. + CUTLASS_HOST_DEVICE + static bool + compute_epilogue(WorkTileInfo const&) { + return true; + } + + // Performs the reduction across splits for a given output tile. Since this scheduler does + // not split output tiles, no reduction is needed. + template + CUTLASS_DEVICE + static void + fixup(Params const&, WorkTileInfo const&, FrgTensorC&, uint32_t, uint32_t) {} + + // Returns whether the current WorkTileInfo passed in should continue to be used. Since + // this scheduler only schedules work in units of single, full output tiles, the WorkTileInfo + // passed in should not be used after having been processed. + CUTLASS_DEVICE + static bool + continue_current_work(WorkTileInfo&) { + return false; + } + + // The basic tile scheduler does not require any additional workspace + template + static int + get_workspace_size(Arguments const&, ProblemShape, KernelHardwareInfo const&, uint32_t) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(Arguments const&, void*, cudaStream_t, ProblemShape, KernelHardwareInfo const&, uint32_t) { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE + static int + get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape problem_shape, TileShape tile_shape) { + // All work units returned by this scheduler cover the entire K iteration + // space of the output tile assigned to the work unit. + return cute::size(cute::ceil_div(cute::get<2>(problem_shape), cute::get<2>(tile_shape))); + } }; } // namespace cutlass::gemm::kernel::detail diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp new file mode 100644 index 0000000000..4049255eb8 --- /dev/null +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp @@ -0,0 +1,992 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/barrier.h" +#include "cutlass/block_striped.h" +#include "cutlass/fast_math.h" +#include "cutlass/workspace.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/layout.hpp" +#include "cute/tensor.hpp" + +namespace cutlass::gemm::kernel::detail { + +// Persistent Thread Block (TB) scheduler leveraging stream-K decomposition +template < + class TileShape, + class ClusterShape +> +class PersistentTileSchedulerSm90StreamK { + // + // Data members + // + +private: + using UnderlyingScheduler = PersistentTileSchedulerSm90; +public: + using RasterOrder = UnderlyingScheduler::RasterOrder; +private: + using UnderlyingArguments = typename UnderlyingScheduler::Arguments; + using UnderlyingParams = typename UnderlyingScheduler::Params; + + uint64_t current_work_linear_idx_ = 0; + + // Minimum number of k iterations that can be assigned to a stream-K unit + static constexpr uint32_t min_iters_per_sk_unit_ = 2; + + // Use a dummy barrier manager to simply get the type used to store the barrier + using BarrierType = typename NamedBarrierManager<1>::T; + +public: + + struct WorkTileInfo { + int32_t M_idx = 0; + int32_t N_idx = 0; + int32_t K_idx = 0; + int32_t L_idx = 0; + bool is_valid_tile = false; + + // Number of splits to be used in computing the {L_idx, M_idx, N_idx} output tile. + // Splits = 1 indicates that this is a data-parallel block. + uint32_t splits = 1; + + // Number of k iterations to compute for the current tile + uint32_t k_tile_count = 0; + + // Number of k iterations remaining for the work unit as a whole + uint32_t k_tile_remaining = 0; + + // Whether this unit of work is the final split for the given tile + bool is_final_split = true; + }; + + struct Arguments { + + Arguments() = default; + Arguments(Arguments const&) = default; + Arguments(Arguments&&) = default; + + CUTLASS_HOST_DEVICE + Arguments& + operator=(Arguments const& args) { + splits = args.splits; + return *this; + } + + CUTLASS_HOST_DEVICE + Arguments& + operator=(Arguments&& args) noexcept { + splits = args.splits; + return *this; + } + + CUTLASS_HOST_DEVICE + Arguments(int splits_) : splits(splits_) {} + + // The splitting factor to be used in a split-K decomposition of the problem. + // If this is set to a value greater than 1, stream-K decomposition logic + // is bypassed in favor of a split-K decomposition. + int splits = 1; + const int max_swizzle_size = 1; + }; + + struct Params { + FastDivmodU64 divmod_cluster_shape_major_{}; + FastDivmodU64 divmod_cluster_shape_minor_{}; + FastDivmodU64 divmod_batch_{}; + FastDivmodU64 divmod_k_{}; + FastDivmodU64 divmod_cluster_blk_major_{}; + + int32_t log_swizzle_size_ = 0; + + + uint64_t units_per_problem_ = 0; + RasterOrder raster_order_ = RasterOrder::AlongN; + ClusterShape cluster_shape_{}; + + // The splitting factor to be used in a split-K decomposition of the problem. + // If this is set to a value greater than 1, stream-K decomposition logic + // is bypassed in favor of a split-K decomposition. + uint32_t splits_ = 1; + + // Number of tiled k iterations required to compute a single output tile. + uint32_t k_iter_per_tile_ = 0; + + // Number of stream-K or split-K work units that compute an extra k iteration. + // This is done to handle residuals in dividing up the k iteration space. + // For stream-K, since the actual assignment of work to stream-K units will be done + // at the granularity of a cluster, we store only the number of big clusters. + uint32_t big_units_ = 0; + + // Workspace for holding partial accumulators to be reduced across stream-K/split-K units + void* reduction_workspace_ = nullptr; + + // Number of tiles covered by stream-K work units + uint32_t sk_tiles_ = 0; + + // Number of work units computing stream-K tiles + uint32_t sk_units_ = 0; + + // Number of tiled k iterations computed by each stream-K work unit. This + // can potentially cover more than one output tile. + uint32_t k_iter_per_sk_unit_ = 0; + }; + + // Sink scheduler params as a member + Params scheduler_params; + + // + // Methods + // + + template + static Params + to_underlying_arguments( + ProblemShape problem_shape_mnkl, + TileShape tile_shape, + ClusterShape cluster_shape, + KernelHardwareInfo const& hw_info, + Arguments const& args, + void* workspace) { + + static_assert(cute::is_static::value); + static_assert(cute::is_static::value); + + // Round up to nearest multiple of cluster dim along each mode + auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = get_tiled_cta_shape_mnl( + problem_shape_mnkl, tile_shape, cluster_shape); + + uint64_t output_tiles = problem_blocks_m * problem_blocks_n * problem_blocks_l; + + // Number of k iterations each tile computes (this is just the number of k iterations + // in the problem's K dimension) + uint32_t k_iter_per_tile = (cute::size<2>(problem_shape_mnkl) + cute::size<2>(tile_shape) - 1) / cute::size<2>(tile_shape); + + UnderlyingArguments underlying_args; + underlying_args.max_swizzle_size = 1; + UnderlyingParams underlying_params = UnderlyingScheduler::to_underlying_arguments( + problem_shape_mnkl, tile_shape, cluster_shape, hw_info, underlying_args, workspace); + + void* reduction_workspace = nullptr; + + if (workspace != nullptr) { + // Reduction workspace is at the beginning of the workspace. Lock workspace follows. + reduction_workspace = workspace; + } + + if (args.splits > 1) { + // Short circuit to basic split-K decomposition + + // Don't split by more than the available number of SMs + auto splits = args.splits > hw_info.sm_count ? hw_info.sm_count : args.splits; + + // Don't split by more than the K tile iterations + // + // splits is almost certainly nonnegative here (e.g., hw_info.sm_count, + // despite being an int, is a count), so it can safely be converted to unsigned + // in the comparison to avoid a signed-unsigned comparison warning-as-error. + splits = static_cast(splits) > k_iter_per_tile ? k_iter_per_tile : splits; + + return get_params_basic( + underlying_params, problem_blocks_m, problem_blocks_n, problem_blocks_l, cluster_shape, + splits, k_iter_per_tile, reduction_workspace); + } + + // Calculate the maximum number of blocks from clusters of shape cluster_shape that we + // can fit within sm_count SMs. + dim3 grid = get_grid_shape(problem_shape_mnkl, tile_shape, cluster_shape, hw_info, args); + uint64_t ctas_per_wave = grid.x * grid.y; + + // The number of output tiles to be computed in stream-K and data-parallel fashion, respectively. + uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave); + uint64_t dp_tiles = output_tiles - sk_tiles; + + // Calculate the number of work units covering the data-parallel and stream-K tiles. + // A "work unit" is a single index in the linearized ID space used by the scheduler. + // We distinguish it from a "block," which is typically tied to a hardware unit + // (e.g., the callers into this scheduler will be persistent thread blocks). + // A work unit can encompass multiple output tiles worth of work (as will be the + // case for stream-K blocks). + // Since splitting is not required for data-parallel tiles, only one data-parallel unit + // is needed per data-parallel tile. + uint64_t dp_units = dp_tiles; + + // Number of k iterations computed by the stream-K units as a whole + uint64_t k_iter_sk_total = k_iter_per_tile * sk_tiles; + + // If there are stream-K tiles to compute and a sufficiently large number of k iterations + // across them, they will be covered by a single wave of persistent threadblocks. Thus, there + // will be as many work units as there are threadblocks in a single wave. + // + // When the total k iterations across stream-K tiles is too small to justify distributing + // across an entire wave of blocks, we instead distribute the iterations over a smaller + // set of blocks. + + // Calculate the number of stream-K units that would be needed if each stream-K unit + // computed the minimum allowable k iterations. Truncate this to be in units of clusters. + uint64_t min_sized_sk_units = (k_iter_sk_total / min_iters_per_sk_unit_); + min_sized_sk_units = (min_sized_sk_units / cute::size(cluster_shape)) * cute::size(cluster_shape); + + uint64_t sk_units = min(ctas_per_wave, min_sized_sk_units); + + if (sk_units == 0) { + // Short circuit to basic data-parallel decomposition + return get_params_basic( + underlying_params, problem_blocks_m, problem_blocks_n, problem_blocks_l, cluster_shape, + 1, k_iter_per_tile, reduction_workspace); + } + + // If the number of stream-K units is a multiple of the number of stream-K tiles, then + // the problem can leverage a basic split-K decomposition for the stream-K tiles. + if (sk_tiles < sk_units && sk_units % sk_tiles == 0) { + // Short circuit to basic split-K decomposition + uint32_t sk_splits = static_cast(sk_units / sk_tiles); + return get_params_basic( + underlying_params, problem_blocks_m, problem_blocks_n, problem_blocks_l, cluster_shape, + sk_splits, k_iter_per_tile, reduction_workspace); + } + + // Number of k iterations computed per stream-K units + uint64_t k_iter_per_sk_unit = k_iter_sk_total / sk_units; + + // Number of stream-K units that need to compute extra iterations in order to cover + // the residual k iterations. This assumes that each such unit computes one additional + // iteration. + uint64_t sk_big_units = k_iter_sk_total - (k_iter_per_sk_unit * sk_units); + + // The division below is guaranteed to be exact because sk_big_units is guaranteed + // to be a multiple of cluster_size (cute::size(cluster_shape)). This is useful because + // it allows us to use a block's linearized cluster ID to determine whether it is + // a big block. The reasoning behind this guarnatee is explained as follows: + // sk_big_units = k_iter_sk_total - (k_iter_per_sk_unit * sk_units); + // + // - k_iter_sk_total is a multiple of cluster_size because it is the product + // of number of tail tiles and the number of k iterations per tile. Because + // both the number of output tiles and number of available SMs are rounded + // to be multiples of cluster shape, the number of tail tiles + // (output_tiles % avail_sms) is a multpile of cluster_size. + // + // - sk_units is a multiple of cluster_size because it is either blocks_per_wave + // or 0, and blocks_per_wave is a multiple of the cluster_size due to the grid-planning + // logic rounding to multiples of cluster dimensions + uint64_t sk_big_units_per_cluster = sk_big_units / cute::size(cluster_shape); + + return { + underlying_params.divmod_cluster_shape_major_, + underlying_params.divmod_cluster_shape_minor_, + underlying_params.divmod_batch_, + FastDivmodU64(problem_blocks_m * problem_blocks_n), // Static k-splitting divmod. Unused for stream-K. + underlying_params.divmod_cluster_blk_major_, + underlying_params.log_swizzle_size_, + static_cast(dp_units + sk_units), + underlying_params.raster_order_, + cluster_shape, + 1, // Static k-splitting factor. Unused for stream-K. + k_iter_per_tile, + static_cast(sk_big_units_per_cluster), + reduction_workspace, + sk_tiles, + static_cast(sk_units), + static_cast(k_iter_per_sk_unit) + }; + } + + CUTLASS_HOST_DEVICE + PersistentTileSchedulerSm90StreamK() { }; + + CUTLASS_HOST_DEVICE + PersistentTileSchedulerSm90StreamK(Params const& params_) : scheduler_params(params_) { + if (params_.raster_order_ == RasterOrder::AlongN) { + current_work_linear_idx_ = static_cast(int(blockIdx.x) + (int(blockIdx.y) * int(gridDim.x))); + } + else { + current_work_linear_idx_ = static_cast((int(blockIdx.x) * int(gridDim.y)) + int(blockIdx.y)); + } + } + + CUTLASS_DEVICE + WorkTileInfo + get_current_work() const { + return get_current_work_for_linear_idx(current_work_linear_idx_); + } + + CUTLASS_DEVICE + WorkTileInfo + get_current_work_for_linear_idx(uint64_t linear_idx) const { + if (linear_idx >= scheduler_params.units_per_problem_) { + // Invalid work. Return an empty result. + return {0, 0, 0, 0, false, 0}; + } + + // Determine whether this work unit is a data-parallel or stream-K work unit + bool is_stream_k_unit = linear_idx < scheduler_params.sk_units_; + + bool is_split_k = scheduler_params.splits_ > 1; + + // Bypass the stream-K scheduling logic for basic data-parallel or split-K work + if (is_split_k || !is_stream_k_unit) { + // The linearized ID space is in terms of work units, rather than tiles. However, + // to compute the correct block offset for a data-parallel tile, we must convert + // the current ID to the data-parallel tile it corresponds to. Each data-parallel + // unit maps to a single data-parallel tile, but each stream-K unit can map to more + // than one tile. Thus, we must offset the work-unit ID among the data-parallel units + // by the total number of output tiles that will be computed by stream-K units. + // + // The logic below also works for the split-K case, in which sk_units_ and sk_tiles_ + // are each 0. + uint64_t linear_work_idx = linear_idx - scheduler_params.sk_units_ + scheduler_params.sk_tiles_; + + // Map worker's linear index into the CTA-tiled problem shape to the corresponding MNL indices + uint64_t work_idx_l, remainder; + scheduler_params.divmod_batch_(work_idx_l, remainder, linear_work_idx); + + uint64_t work_idx_k = 0; + if (is_split_k) { + scheduler_params.divmod_k_(work_idx_k, remainder, remainder); + } + + uint64_t cta_per_grid_dim, dontcare; + scheduler_params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder); + + auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n( + cta_per_grid_dim, + scheduler_params.divmod_cluster_shape_major_, + scheduler_params.divmod_cluster_shape_minor_, + scheduler_params.divmod_cluster_blk_major_, + scheduler_params.log_swizzle_size_, + scheduler_params.raster_order_); + + bool is_final_split = (work_idx_k == scheduler_params.splits_ - 1); + + uint32_t k_iter = scheduler_params.k_iter_per_tile_; + if (is_split_k) { + // Determine the number of iterations and starting iteration of this split. + // Doing so requires accounting for residual iterations, which are handled + // by the first big_units_ splits (with big_units_ = tiles % sm_count). + + // Offsets for "normal" units. No additional k iterations are performed, + // and big_units_ "big" units preceded us, each of which performed one + // additional iteration. Thus, we must increase our split starting offset + // by big_units_. + int additional_k_iter = 0; + int split_start_offset = scheduler_params.big_units_; + + if (work_idx_k < scheduler_params.big_units_) { + // Offsets for "big" units. One additional k iteration is performed, + // and each split preceding us was a big unit, so we must increase + // our split starting offset by our split ID (work_idx_k). + additional_k_iter = 1; + split_start_offset = work_idx_k; + } + + // Set up k iteration count and split starting iteration assuming the + // iteration space is evenly split. + k_iter /= scheduler_params.splits_; + work_idx_k *= k_iter; + + // Apply any fixup needed to handle residuals + work_idx_k += split_start_offset; + k_iter += additional_k_iter; + } + + return { + work_idx_m, + work_idx_n, + static_cast(work_idx_k), + static_cast(work_idx_l), + true, + scheduler_params.k_iter_per_tile_, + k_iter, + k_iter, // remaining iterations + is_final_split + }; + } + + // This is a stream-K work unit + WorkTileInfo work_tile_info; + set_stream_k_work(linear_idx, work_tile_info, /*new_unit = */ true); + return work_tile_info; + } + + // Returns whether the current work_tile_info passed in should continue to be used. This + // occurs only in the stream-K decomposition with stream-K work units, which encompass + // work over multiple output tiles. If the current work_tile_info should continue to be + // used, it is updated to advance to the next output tile it should cover. + CUTLASS_DEVICE + bool + continue_current_work(WorkTileInfo& work_tile_info) const { + work_tile_info.k_tile_remaining -= work_tile_info.k_tile_count; + + if (work_tile_info.k_tile_remaining == 0) { + return false; + } + + set_stream_k_work(current_work_linear_idx_, work_tile_info, /* new_unit = */ false); + return true; + } + + CUTLASS_DEVICE + void + advance_to_next_work(uint32_t advance_count = 1) { + current_work_linear_idx_ += static_cast(int(gridDim.x) * int(gridDim.y) * int(gridDim.z)) * advance_count; + } + + // Given the inputs, computes the total number of output blocks this problem will compute over + // Note that this is only the logical size of our grid, not the physical grid we will actually launch. + template + CUTLASS_HOST_DEVICE static + dim3 + get_tiled_cta_shape_mnl(ProblemShape problem_shape_mnkl, TileShape cta_shape, ClusterShape cluster_shape) { + return UnderlyingScheduler::get_tiled_cta_shape_mnl(problem_shape_mnkl, cta_shape, cluster_shape); + } + + // Given the cluster shape, computes the physical grid we should launch. + template + CUTLASS_HOST_DEVICE static + dim3 + get_grid_shape( + ProblemShape problem_shape, + TileShape tile_shape, + ClusterShape cluster_shape, + KernelHardwareInfo hw_info, + Arguments arguments) { + + UnderlyingArguments underlying_args; + underlying_args.max_swizzle_size = 1; + // Call into the underlying get_grid_shape method, but do not allow the grid shape returned + // to be truncated based on the number of output tiles in the problem. + return UnderlyingScheduler::get_grid_shape( + problem_shape, + tile_shape, + cluster_shape, + hw_info, + underlying_args, + /*truncate_by_problem_size=*/false); + } + + // Performs the reduction across splits for a given output tile. + template + CUTLASS_DEVICE + static void + fixup( + Params const& params, + WorkTileInfo const& work_tile_info, + FrgTensorC& accumulators, + uint32_t num_barriers, + uint32_t barrier_idx) { + + using ElementAccumulator = typename FrgTensorC::value_type; + + using BarrierManager = NamedBarrierManager; + + if (work_tile_info.k_tile_count == params.k_iter_per_tile_) { + // Fixup is not needed for data-parallel tiles + return; + } + + auto tile_idx = output_tile_index(params, work_tile_info); + + // Index of the lock on which to wait + auto lock_idx = (tile_idx * num_barriers) + barrier_idx; + + // Reductions use BlockStripedReduce with a width of BarrierManager::ThreadCount under the hood. + // Thus, the start of the reduction space is the same across all threads in a warp group. + int reduction_offset = + (cute::size<0>(TileShape{}) * cute::size<1>(TileShape{}) * tile_idx) + + (size(accumulators) * barrier_idx * BarrierManager::ThreadCount); + + ElementAccumulator* group_reduction_workspace = reinterpret_cast(params.reduction_workspace_) + reduction_offset; + + using AccumulatorArrayT = Array; + using BlockStripedReduceT = BlockStripedReduce; + + AccumulatorArrayT* reduction_workspace_array = reinterpret_cast(group_reduction_workspace); + AccumulatorArrayT* accumulator_array = reinterpret_cast(&accumulators); + + int barrier_group_thread_idx = threadIdx.x % BarrierManager::ThreadCount; + + // The number of tiles for which reduction is required is either: + // (a) the total number of output tiles (in the case of split-K) + // (b) the number of stream-K tiles + // To calcualte the the total number of output tiles in the split-K case, we + // note that, in the split-K case, the units_per_problem_ member of Params will be + // the total number of output tiles multiplied by the number of splits. + auto reduction_tiles = params.splits_ > 1 ? (params.units_per_problem_ / params.splits_) : params.sk_tiles_; + auto reduction_workspace_size = get_reduction_workspace_size(reduction_tiles); + BarrierType* lock_workspace = reinterpret_cast( + reinterpret_cast(params.reduction_workspace_) + reduction_workspace_size); + + if (!work_tile_info.is_final_split) { + if (work_tile_info.K_idx == 0) { + // First peer initializes the workspace partials + BlockStripedReduceT::store(reduction_workspace_array, *accumulator_array, barrier_group_thread_idx); + } + else { + // Wait until the preceding split added its accumulators + BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx); + + // Perform reduction in workspace + BlockStripedReduceT::reduce(reduction_workspace_array, *accumulator_array, barrier_group_thread_idx); + } + + // Signal our arrival + BarrierManager::arrive_inc(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.k_tile_count); + } + else { + // Wait until the preceding split added its accumulators + BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx); + + // The block computing the final split for the tile adds previously-reduced partials + // to its accumulators and computes the epilogue. + BlockStripedReduceT::load_add(*accumulator_array, reduction_workspace_array, barrier_group_thread_idx); + } + } + + // Splits an input tensor with MxK according to the splitting configuration specified by work_tile_info + template + CUTLASS_DEVICE + static auto + split_MK(cute::Tensor const& tensor, WorkTileInfo const& work_tile_info) { + return split(tensor, work_tile_info); + } + + // Splits an input tensor with NxK tiles according to the splitting configuration specified by work_tile_info + template + CUTLASS_DEVICE + static auto + split_NK(cute::Tensor const& tensor, WorkTileInfo const& work_tile_info) { + return split(tensor, work_tile_info); + } + + // Returns whether the block assigned this work should compute the epilogue for the corresponding + // output tile. For the case of stream-K, this should only occur if the work is marked as the final split. + CUTLASS_HOST_DEVICE + static bool + compute_epilogue(WorkTileInfo const& work_tile_info) { + return work_tile_info.is_final_split; + } + + // Returns the linearized index of the output tile corresponding to the tile with offset [L, M, K] + CUTLASS_HOST_DEVICE + static int + output_tile_index(Params const& params, WorkTileInfo const& work_tile_info) { + if (params.splits_ > 1) { + auto tiles_mn = params.divmod_batch_.divisor / params.splits_; + if (params.raster_order_ == RasterOrder::AlongN) { + return + (tiles_mn * work_tile_info.L_idx) + + (params.divmod_cluster_shape_major_.divisor * + params.divmod_cluster_blk_major_.divisor * work_tile_info.M_idx) + + work_tile_info.N_idx; + } + else { + return + (tiles_mn * work_tile_info.L_idx) + + (params.divmod_cluster_shape_major_.divisor * + params.divmod_cluster_blk_major_.divisor * work_tile_info.N_idx) + + work_tile_info.M_idx; + } + } + else { + uint64_t cta_per_grid_dim; + uint64_t cluster_dim_idx; + if (params.raster_order_ == RasterOrder::AlongN) { + uint64_t block_idx_m = (work_tile_info.M_idx - blockIdx.x) / gridDim.x; + uint64_t block_idx_n = work_tile_info.N_idx; + cta_per_grid_dim = (params.divmod_cluster_shape_major_.divisor * + params.divmod_cluster_blk_major_.divisor * block_idx_m) + block_idx_n; + cluster_dim_idx = blockIdx.x; + } + else { + uint64_t block_idx_m = work_tile_info.M_idx; + uint64_t block_idx_n = (work_tile_info.N_idx - blockIdx.y) / gridDim.y; + cta_per_grid_dim = (params.divmod_cluster_shape_major_.divisor * + params.divmod_cluster_blk_major_.divisor * block_idx_n) + block_idx_m; + cluster_dim_idx = blockIdx.y; + } + + uint64_t tile_in_batch = params.divmod_cluster_shape_minor_.divisor * cta_per_grid_dim; + return params.divmod_batch_.divisor * work_tile_info.L_idx + tile_in_batch + cluster_dim_idx; + } + } + + template + static int + get_workspace_size( + Arguments const& args, + ProblemShape problem_shape, + KernelHardwareInfo const& hw_info, + uint32_t mma_warp_groups) { + + int barrier_workspace_size = 0; + int reduction_workspace_size = 0; + + get_workspace_component_sizes( + args, problem_shape, barrier_workspace_size, reduction_workspace_size, hw_info, mma_warp_groups); + + return barrier_workspace_size + reduction_workspace_size; + } + + template + static cutlass::Status + initialize_workspace( + Arguments const& args, + void* workspace, + cudaStream_t stream, + ProblemShape const& problem_shape, + KernelHardwareInfo const& hw_info, + uint32_t mma_warp_groups) { + + #if !defined(__CUDACC_RTC__) + int barrier_workspace_size = 0; + int reduction_workspace_size = 0; + + get_workspace_component_sizes( + args, problem_shape, barrier_workspace_size, reduction_workspace_size, hw_info, mma_warp_groups); + + if (barrier_workspace_size > 0) { + if (workspace == nullptr) { + return Status::kErrorWorkspaceNull; + } + + // Only the barrier workspace needs to be cleared for stream-K. + // Barrier workspace follows reduction workspace. + uint8_t* barrier_workspace = reinterpret_cast(workspace) + reduction_workspace_size; + return zero_workspace(static_cast(barrier_workspace), barrier_workspace_size, stream); + } + + return Status::kSuccess; + #endif + } + + template + CUTLASS_HOST_DEVICE + static int + get_work_k_tile_count(WorkTileInfo const& work_tile_info, ProblemShape, TileShape) { + return work_tile_info.k_tile_count; + } + +private: + // Splits a tensor using the splitting configuration specified by work_tile_info using + // a MN shape detemined by TileDim0. + template + CUTLASS_DEVICE + static auto + split(cute::Tensor const& tensor, WorkTileInfo const& work_tile_info) { + using namespace cute; + + // Divide input tensor into `splits` chunks along the k dimension + auto div_shape = make_shape(size(TileShape{}), size<2>(TileShape{}), work_tile_info.splits); + auto split = zipped_divide(tensor, div_shape); + + // Index into the split tensor at the work tile's split index + auto indexed = split(make_coord(make_coord(_, _, work_tile_info.K_idx), make_coord(0, 0, _))); + + // Construct a layout for the indexed tensor. The main purpose of this new layout is to + // override the k extent to support cases in which the split computes a number of iterations + // not equal to total_tile_k_iter / splits. A common example of this is in stream-K is when a + // unit computes the final 20 of the total 32 k iterations of the output tile. In this case, + // set splits = 32 and the split index (K_idx) to 11. The zipped divide above results in each + // of the splits computing only one k iteration. + auto overridden_shape = make_shape(size<0>(indexed.layout()), size<1>(indexed.layout()), work_tile_info.k_tile_count); + auto layout = make_layout(overridden_shape, tensor.stride()); + + return make_tensor(indexed.data(), layout); + } + + // Returns the number of stream-K tiles that will be computed amongst `output_tiles` total + // output tiles on a device with `ctas_per_wave` CTAs in each wave. + static uint32_t + get_num_sk_tiles(uint64_t output_tiles, uint64_t ctas_per_wave) { + uint32_t full_waves = static_cast(output_tiles / ctas_per_wave); + uint32_t total_waves = static_cast((output_tiles + ctas_per_wave - 1) / ctas_per_wave); + + if (full_waves == total_waves) { + // No quantization. All tiles will be data-parallel tiles. + return 0; + } + + // + // The final wave is not full. Perform some stream-K work. + // + + // Rudimentary heuristic: prefer data-parallel decomposition if we have more than + // one wave and the tail wave is more than half full. This is subject to change. + if (full_waves != 0) { + uint64_t tail_tiles = output_tiles - (full_waves * ctas_per_wave); + if (tail_tiles >= (ctas_per_wave / 2)) { + return 0; + } + } + + // If there is wave quantization, assign the first two waves worth of tiles to be + // covered by stream-K work and the remainder to be data-parallel. Since we know + // that full_waves == total_waves - 1 in this case, the number of data-parallel + // waves is simply full_waves-1 (unless full_waves == 0). + uint32_t dp_waves = full_waves > 0 ? full_waves - 1 : 0; + + uint64_t dp_tiles = dp_waves * ctas_per_wave; + return static_cast(output_tiles - dp_tiles); + } + + // Calculates the size of the workspace needed for holding reduction barriers + CUTLASS_HOST_DEVICE + static int + get_barrier_workspace_size(uint64_t num_tiles, uint32_t mma_warp_groups) { + auto workspace_bits = num_tiles * mma_warp_groups * sizeof_bits::value; + return bits_to_bytes(static_cast(workspace_bits)); + } + + // Calculates the size of the workspace needed for holding partial outputs from splits + template + CUTLASS_HOST_DEVICE + static int + get_reduction_workspace_size(uint64_t num_tiles) { + auto output_tile_size = cute::size<0>(TileShape{}) * cute::size<1>(TileShape{}); + auto workspace_bits = sizeof_bits::value * output_tile_size * num_tiles; + return bits_to_bytes(static_cast(workspace_bits)); + } + + template + static void + get_workspace_component_sizes( + Arguments const& args, + ProblemShape problem_shape, + int& barrier_workspace_size, + int& reduction_workspace_size, + KernelHardwareInfo const& hw_info, + uint32_t mma_warp_groups) { + + // Workspace is needed only for output tiles that will be split. Thus, we first determine the number + // of output tiles that will be split, and then calculate the workspace needed to cover these. + + auto problem_shape_mnkl = cute::append<4>(problem_shape, 1); + + ClusterShape cluster_shape; + auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = get_tiled_cta_shape_mnl( + problem_shape_mnkl, TileShape{}, cluster_shape); + uint64_t output_tiles = problem_blocks_m * problem_blocks_n * problem_blocks_l; + + if (args.splits > 1) { + // Basic split-K variant requires workspace for all output tiles + barrier_workspace_size = get_barrier_workspace_size(output_tiles, mma_warp_groups); + reduction_workspace_size = get_reduction_workspace_size(output_tiles); + } + else { + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + dim3 grid = get_grid_shape(problem_shape_mnkl, TileShape{}, cluster_shape, {0, sm_count}, args); + uint64_t ctas_per_wave = grid.x * grid.y; + uint32_t sk_tiles = get_num_sk_tiles(output_tiles, ctas_per_wave); + + barrier_workspace_size = get_barrier_workspace_size(sk_tiles, mma_warp_groups); + reduction_workspace_size = get_reduction_workspace_size(sk_tiles); + } + } + + // Constructs parameters for either a basic data-parallel or basic split-K decomposition of the problem + static Params + get_params_basic( + UnderlyingParams const& underlying_params, + uint32_t blocks_m, + uint32_t blocks_n, + uint32_t blocks_l, + ClusterShape cluster_shape, + uint32_t splits, + uint32_t k_iter_per_tile, + void* reduction_workspace) { + + uint32_t big_units = k_iter_per_tile % splits; + + return { + underlying_params.divmod_cluster_shape_major_, + underlying_params.divmod_cluster_shape_minor_, + FastDivmodU64(blocks_m * blocks_n * splits), + FastDivmodU64(blocks_m * blocks_n), + underlying_params.divmod_cluster_blk_major_, + underlying_params.log_swizzle_size_, + blocks_m * blocks_n * blocks_l * splits, + underlying_params.raster_order_, + cluster_shape, + splits, + k_iter_per_tile, + big_units, + reduction_workspace + }; + } + + // Sets the current stream-K work to compute within work_tile_info. If new_unit is true, work_tile_info + // is populated as a new unit of work. Otherwise, state existing in work_tile_info (e.g., remaining + // iterations) is used to find the next tile in the current work unit. + CUTLASS_DEVICE + void + set_stream_k_work(uint64_t linear_idx, WorkTileInfo& work_tile_info, bool new_unit) const { + // In the CUTLASS 2.x implementation of stream K, stream-K work is assigned to each stream-K + // threadblock individually. For the most part, the set of K iterations corresponding to stream-K + // work was divided amongst stream-K threadblocks, and a threadblock determined which tile + // it would compute a (potentially-partial) output tile for based on the space of k iterations + // assigned to it. This often results in stream-K threadblocks processing tiles with different + // offsets in the K dimension from one another. This can reduce locality, but is lmitied to the + // (generally few) waves of threadblocks assigned to compute stream-K work. + // + // With the introduction of threadblock clusters, there is additional benefit to maintaining + // locality in the K dimension: shared portions of operands can be multicasted to threadblocks + // within a cluster. Thus, we would like to ensure that the assignment of stream-K work to + // threadblocks respects the ability to perform multicasting. + // + // To do so, we divide up the linearized stream-K units into clusters and share the same K + // offsets for work within clusters. + auto cluster_linear_work_idx = linear_idx / size(scheduler_params.cluster_shape_); + + // Determine the starting k iteration computed by this stream-K work unit + uint32_t unit_iter_start = scheduler_params.k_iter_per_sk_unit_ * cluster_linear_work_idx; + + // Adjust the starting position and number of k iterations for "big units," which + // compute one extra iteration. These are the first big_units_ units in the + // linearized ID space. + bool is_big_unit = cluster_linear_work_idx < scheduler_params.big_units_; + if (is_big_unit) { + // Since the "big units" are the first units in the linearized ID space, each + // of the units preceding this big unit computed one extra iteration. Thus, + // we must offset our start iteration by the number of units that precede + // the current unit in the linearized ID space. + unit_iter_start += cluster_linear_work_idx; + } else { + // Increment by one for each of the big clusters (since all big units precede this unit) + unit_iter_start += scheduler_params.big_units_; + } + + uint32_t unit_iters; + if (new_unit) { + unit_iters = scheduler_params.k_iter_per_sk_unit_; + + // Only adjust iteration count for big unit if we are initializing this + // work unit. For existing work units, the extra iteration for big units + // has already been accounted for in k_iter_reamaining + if (is_big_unit) { + ++unit_iters; + } + } + else { + unit_iters = work_tile_info.k_tile_remaining; + } + + // Find the output tile corresponding to the final k iteration covered by this + // work unit. Stream-K work units will work backwards in terms of the tiles they + // are responsible computing. This is beneficial because the final (partial) + // tile computed by a stream-K block is typically the beginning of the output + // tile, while the beginning (partial) tile is typically the ending of another + // output tile. Since ending portions of an output tile must reduce across + // other work units computing portions of that output tile, it is preferable + // for them to be computed later, so as to reduce the likelihood of blocking + // on other work. + uint32_t unit_iter_end = unit_iter_start + unit_iters - 1; + uint32_t true_tile_id = unit_iter_end / scheduler_params.k_iter_per_tile_; + uint32_t true_tile_iter_start = true_tile_id * scheduler_params.k_iter_per_tile_; + uint32_t true_tile_iter_end = true_tile_iter_start + scheduler_params.k_iter_per_tile_; + + // Bring the linearized tile ID back into the space of tiles, rather than clusters + true_tile_id *= size(scheduler_params.cluster_shape_); + + auto cluster_dim0 = cute::size<0>(scheduler_params.cluster_shape_); + auto cluster_dim1 = cute::size<1>(scheduler_params.cluster_shape_); + + // The final linearized tile ID is in units of the cluster dimension over which we rasterize. + if (scheduler_params.raster_order_ == RasterOrder::AlongN) { + true_tile_id += (blockIdx.y % cluster_dim1) * cluster_dim0; + } + else { + true_tile_id += (blockIdx.x % cluster_dim0) * cluster_dim1; + } + + // The unit's starting k iteration in the current tile is either the starting + // iteration for the tile as a whole, or the starting k iteration for the unit + // as a whole (if the latter is greater than the former). + uint32_t tile_iter_start = max(true_tile_iter_start, unit_iter_start); + + // Similarly, the unit's ending k iteration (exclusive) is either the end of + // the current tile it is assigned, or the ending iteration of the unit as a whole + // (if the latter is less than the former). + uint32_t tile_iter_end = min(true_tile_iter_end, unit_iter_end + 1); + + uint32_t tile_iters = tile_iter_end - tile_iter_start; + + uint64_t work_idx_l, remainder; + scheduler_params.divmod_batch_(work_idx_l, remainder, true_tile_id); + + uint64_t cta_per_grid_dim, dontcare; + scheduler_params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder); + + + auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n( + cta_per_grid_dim, + scheduler_params.divmod_cluster_shape_major_, + scheduler_params.divmod_cluster_shape_minor_, + scheduler_params.divmod_cluster_blk_major_, + scheduler_params.log_swizzle_size_, + scheduler_params.raster_order_); + + // + // Update the work_tile_info + // + + // Set the M, N, and L block offsets + work_tile_info.M_idx = work_idx_m; + work_tile_info.N_idx = work_idx_n; + work_tile_info.L_idx = static_cast(work_idx_l); + + // Set the k offset to be the starting k iteration for this tile + work_tile_info.K_idx = static_cast(tile_iter_start - true_tile_iter_start); + + // Set the split count to be the number of k iterations in the tile + work_tile_info.splits = scheduler_params.k_iter_per_tile_; + + // Any checks for invalid work units should be done prior to this call + work_tile_info.is_valid_tile = true; + + work_tile_info.k_tile_count = tile_iters; + work_tile_info.k_tile_remaining = unit_iters; + + // Compute the epilogue if this unit of work contains the ending k iteration for + // the output tile in question + work_tile_info.is_final_split = (tile_iter_end == true_tile_iter_end); + } +}; + +} // namespace cutlass::gemm::kernel::detail diff --git a/include/cutlass/gemm/kernel/sparse_gemm.h b/include/cutlass/gemm/kernel/sparse_gemm.h index eba95aad4c..1964fba8bc 100644 --- a/include/cutlass/gemm/kernel/sparse_gemm.h +++ b/include/cutlass/gemm/kernel/sparse_gemm.h @@ -277,7 +277,7 @@ struct SparseGemm { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; // diff --git a/include/cutlass/gemm/kernel/symm_universal.h b/include/cutlass/gemm/kernel/symm_universal.h index 47e7035abe..f05cf7df9b 100755 --- a/include/cutlass/gemm/kernel/symm_universal.h +++ b/include/cutlass/gemm/kernel/symm_universal.h @@ -415,7 +415,7 @@ struct SymmUniversal { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/kernel/tile_scheduler.hpp b/include/cutlass/gemm/kernel/tile_scheduler.hpp new file mode 100644 index 0000000000..a81460e4f6 --- /dev/null +++ b/include/cutlass/gemm/kernel/tile_scheduler.hpp @@ -0,0 +1,129 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +/*! \file + \brief Utilities for selecting default tile schedulers +*/ + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm { + +//////////////////////////////////////////////////////////////////////////////// + +// +// Tags for specifying tile schedulers +// + +struct PersistentScheduler { }; + +struct StreamKScheduler { }; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel::detail { + +// +// Selectors mapping tile scheduler tag and arch tag to a tile scheduler class +// + +template < + class TileSchedulerTag, + class ArchTag, + class TileShape, + class ClusterShape +> +struct TileSchedulerSelector { + static_assert(cutlass::detail::dependent_false, + "Could not select a tile scheduler for given parameters."); +}; + +template < + class ArchTag, + class TileShape, + class ClusterShape +> +struct TileSchedulerSelector< + PersistentScheduler, + ArchTag, + TileShape, + ClusterShape + > { + using Scheduler = PersistentTileSchedulerSm90; +}; + +// Default (void) for Sm90 maps to PersistentTileSchedulerSm90 +template < + class ArchTag, + class TileShape, + class ClusterShape +> +struct TileSchedulerSelector< + void, + ArchTag, + TileShape, + ClusterShape + > { + using Scheduler = typename TileSchedulerSelector< + PersistentScheduler, + ArchTag, + TileShape, + ClusterShape + >::Scheduler; +}; + +template < + class TileShape, + class ClusterShape +> +struct TileSchedulerSelector< + StreamKScheduler, + arch::Sm90, + TileShape, + ClusterShape + > { + using Scheduler = PersistentTileSchedulerSm90StreamK; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel::detail + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/kernel/trmm_universal.h b/include/cutlass/gemm/kernel/trmm_universal.h index 7ba223bbb4..bca9450b8e 100644 --- a/include/cutlass/gemm/kernel/trmm_universal.h +++ b/include/cutlass/gemm/kernel/trmm_universal.h @@ -380,7 +380,7 @@ struct TrmmUniversal { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx(); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; diff --git a/include/cutlass/gemm/threadblock/default_mma_core_sm80.h b/include/cutlass/gemm/threadblock/default_mma_core_sm80.h index ad232fce66..39a6454d00 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_sm80.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_sm80.h @@ -2013,7 +2013,7 @@ struct DefaultMmaCore= 0; --n) { @@ -536,7 +535,6 @@ class MmaComplexTensorOp< CUTLASS_DEVICE void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, FragmentA const &A, FragmentB const &B) const { - //TODO: Implement this dst_A = A; dst_B = B; } @@ -1161,8 +1159,6 @@ class MmaComplexTensorOp< ///////////////////////////////////////////////////////////////////////////////////////////////// -// TODO - partial specializations of real*complex and complex*real - ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace warp diff --git a/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h b/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h index d872012b40..02fd4c077f 100644 --- a/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h +++ b/include/cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h @@ -530,7 +530,6 @@ class MmaTensorOpMultiplicandTileIterator< Fragment &frag, /// loads a tile with a logical offset in units of whole tiles TensorCoord const &tile_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -542,7 +541,6 @@ class MmaTensorOpMultiplicandTileIterator< TensorCoord const &tile_offset, /// loads a tile with a logical offset AND a pointer offset Index pointer_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -757,7 +755,6 @@ class MmaTensorOpMultiplicandTileIterator< Fragment &frag, /// loads a tile with a logical offset in units of whole tiles TensorCoord const &tile_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -769,7 +766,6 @@ class MmaTensorOpMultiplicandTileIterator< TensorCoord const &tile_offset, /// loads a tile with a logical offset AND a pointer offset Index pointer_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -1588,7 +1584,6 @@ class MmaTensorOpMultiplicandTileIterator< Fragment &frag, /// loads a tile with a logical offset in units of whole tiles TensorCoord const &tile_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -1600,7 +1595,6 @@ class MmaTensorOpMultiplicandTileIterator< TensorCoord const &tile_offset, /// loads a tile with a logical offset AND a pointer offset Index pointer_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -1816,7 +1810,6 @@ class MmaTensorOpMultiplicandTileIterator< Fragment &frag, /// loads a tile with a logical offset in units of whole tiles TensorCoord const &tile_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -1828,7 +1821,6 @@ class MmaTensorOpMultiplicandTileIterator< TensorCoord const &tile_offset, /// loads a tile with a logical offset AND a pointer offset Index pointer_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. diff --git a/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h b/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h index 00760a6be7..31a661f723 100644 --- a/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_gaussian_complex_tensor_op.h @@ -350,7 +350,6 @@ class MmaGaussianComplexTensorOp< CUTLASS_DEVICE void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, FragmentA const &A, FragmentB const &B) const { - //TODO: Implement this dst_A = A; dst_B = B; } diff --git a/include/cutlass/gemm/warp/mma_simt.h b/include/cutlass/gemm/warp/mma_simt.h index 9790792367..ecde134a9d 100644 --- a/include/cutlass/gemm/warp/mma_simt.h +++ b/include/cutlass/gemm/warp/mma_simt.h @@ -251,7 +251,6 @@ class MmaSimt { CUTLASS_DEVICE void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, FragmentA const &A, FragmentB const &B) const { - //TODO: Implement this dst_A = A; dst_B = B; } diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h index 54f194fc30..ac042cbc71 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h @@ -1031,7 +1031,6 @@ class MmaTensorOpMultiplicandTileIterator< Fragment &frag, /// loads a tile with a logical offset in units of whole tiles TensorCoord const &tile_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -1043,7 +1042,6 @@ class MmaTensorOpMultiplicandTileIterator< TensorCoord const &tile_offset, /// loads a tile with a logical offset AND a pointer offset Index pointer_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -1262,7 +1260,6 @@ class MmaTensorOpMultiplicandTileIterator< Fragment &frag, /// loads a tile with a logical offset in units of whole tiles TensorCoord const &tile_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -1274,7 +1271,6 @@ class MmaTensorOpMultiplicandTileIterator< TensorCoord const &tile_offset, /// loads a tile with a logical offset AND a pointer offset Index pointer_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -2060,7 +2056,6 @@ class MmaTensorOpMultiplicandTileIterator< Fragment &frag, /// loads a tile with a logical offset in units of whole tiles TensorCoord const &tile_offset) const { - // TODO assert(0); } @@ -2073,7 +2068,6 @@ class MmaTensorOpMultiplicandTileIterator< TensorCoord const &tile_offset, /// loads a tile with a logical offset AND a pointer offset Index pointer_offset) const { - // TODO assert(0); } @@ -2300,7 +2294,6 @@ class MmaTensorOpMultiplicandTileIterator< Fragment &frag, /// loads a tile with a logical offset in units of whole tiles TensorCoord const &tile_offset) const { - // TODO assert(0); } @@ -2313,7 +2306,6 @@ class MmaTensorOpMultiplicandTileIterator< TensorCoord const &tile_offset, /// loads a tile with a logical offset AND a pointer offset Index pointer_offset) const { - // TODO assert(0); } diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h index bf192e6afc..b79b43e728 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h @@ -857,7 +857,6 @@ class MmaVoltaTensorOpMultiplicandTileIterator< Fragment &frag, /// loads a tile with a logical offset in units of whole tiles TensorCoord const &tile_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -869,7 +868,6 @@ class MmaVoltaTensorOpMultiplicandTileIterator< TensorCoord const &tile_offset, /// loads a tile with a logical offset AND a pointer offset Index pointer_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -1081,7 +1079,6 @@ class MmaVoltaTensorOpMultiplicandTileIterator< Fragment &frag, /// loads a tile with a logical offset in units of whole tiles TensorCoord const &tile_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -1093,7 +1090,6 @@ class MmaVoltaTensorOpMultiplicandTileIterator< TensorCoord const &tile_offset, /// loads a tile with a logical offset AND a pointer offset Index pointer_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -1987,7 +1983,6 @@ class MmaVoltaTensorOpMultiplicandTileIterator< Fragment &frag, /// loads a tile with a logical offset in units of whole tiles TensorCoord const &tile_offset) const { - // TODO assert(0); } @@ -2000,7 +1995,6 @@ class MmaVoltaTensorOpMultiplicandTileIterator< TensorCoord const &tile_offset, /// loads a tile with a logical offset AND a pointer offset Index pointer_offset) const { - // TODO assert(0); } @@ -2215,7 +2209,6 @@ class MmaVoltaTensorOpMultiplicandTileIterator< Fragment &frag, /// loads a tile with a logical offset in units of whole tiles TensorCoord const &tile_offset) const { - // TODO assert(0); } @@ -2228,7 +2221,6 @@ class MmaVoltaTensorOpMultiplicandTileIterator< TensorCoord const &tile_offset, /// loads a tile with a logical offset AND a pointer offset Index pointer_offset) const { - // TODO assert(0); } diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h index 29cc3d9f3e..beeff23830 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h @@ -546,7 +546,6 @@ class MmaTensorOpMultiplicandTileIterator< Fragment &frag, /// loads a tile with a logical offset in units of whole tiles TensorCoord const &tile_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -558,7 +557,6 @@ class MmaTensorOpMultiplicandTileIterator< TensorCoord const &tile_offset, /// loads a tile with a logical offset AND a pointer offset Index pointer_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -774,7 +772,6 @@ class MmaTensorOpMultiplicandTileIterator< Fragment &frag, /// loads a tile with a logical offset in units of whole tiles TensorCoord const &tile_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -786,7 +783,6 @@ class MmaTensorOpMultiplicandTileIterator< TensorCoord const &tile_offset, /// loads a tile with a logical offset AND a pointer offset Index pointer_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -999,7 +995,8 @@ class MmaTensorOpMultiplicandTileIterator< CUTLASS_DEVICE MmaTensorOpMultiplicandTileIterator &add_tile_offset_negative(TensorCoord const &tile_offset) { - add_tile_offset(tile_offset); // TODO fix this if it becomes an issue during warp it reset + // TODO: fix this if it becomes an issue during warp it reset + add_tile_offset(tile_offset); return *this; } @@ -1334,7 +1331,6 @@ class MmaTensorOpMultiplicandTileIterator< Fragment &frag, /// loads a tile with a logical offset in units of whole tiles TensorCoord const &tile_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -1346,7 +1342,6 @@ class MmaTensorOpMultiplicandTileIterator< TensorCoord const &tile_offset, /// loads a tile with a logical offset AND a pointer offset Index pointer_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -1567,7 +1562,6 @@ class MmaTensorOpMultiplicandTileIterator< Fragment &frag, /// loads a tile with a logical offset in units of whole tiles TensorCoord const &tile_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -1579,7 +1573,6 @@ class MmaTensorOpMultiplicandTileIterator< TensorCoord const &tile_offset, /// loads a tile with a logical offset AND a pointer offset Index pointer_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -2170,7 +2163,6 @@ class MmaTensorOpMultiplicandTileIterator< Fragment &frag, /// loads a tile with a logical offset in units of whole tiles TensorCoord const &tile_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -2182,7 +2174,6 @@ class MmaTensorOpMultiplicandTileIterator< TensorCoord const &tile_offset, /// loads a tile with a logical offset AND a pointer offset Index pointer_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -2399,7 +2390,6 @@ class MmaTensorOpMultiplicandTileIterator< Fragment &frag, /// loads a tile with a logical offset in units of whole tiles TensorCoord const &tile_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. @@ -2411,7 +2401,6 @@ class MmaTensorOpMultiplicandTileIterator< TensorCoord const &tile_offset, /// loads a tile with a logical offset AND a pointer offset Index pointer_offset) const { - // TODO } /// Loads a fragment from memory with logical offset in units of whole tiles. diff --git a/include/cutlass/gemm/warp/scale_bias_tile_iterator.h b/include/cutlass/gemm/warp/scale_bias_tile_iterator.h index 9c9b90bcc8..aebeea7900 100644 --- a/include/cutlass/gemm/warp/scale_bias_tile_iterator.h +++ b/include/cutlass/gemm/warp/scale_bias_tile_iterator.h @@ -522,7 +522,6 @@ class ScaleBiasTileIterator +struct GemmShape { + static int const kM = M; + static int const kN = N; + static int const kK = K; + + static int const kMN = M * N; + static int const kMK = M * K; + static int const kKN = N * K; + static int const kMNK = M * N * K; + + static int const kCount = kMNK; + + // + // Static member functions + // + + /// Returns a Coord object + CUTLASS_HOST_DEVICE + static Coord<3> toCoord() { + return make_Coord(kM, kN, kK); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Type alias of the transpose of a GemmShape +template < + /// concept: GemmShape + typename Shape +> +using GemmShapeTranspose = GemmShape; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// GemmCoord is a structure derived from Coord<3> that specifies a location within the +/// coordinate space of a GEMM problem. +struct GemmCoord : public Coord<3, int> { + + /// Integer-valued index + typedef int Index; + + /// Base type is a Coord of rank=3 + typedef Coord<3, Index> Base; + + /// GEMM M dimension - rows of the output C matrix + static int const kM = 0; + + /// GEMM N dimension - columns of the output C matrix + static int const kN = 1; + + /// GEMM K dimension - inner dimension of the GEMM problem + static int const kK = 2; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + GemmCoord() { } + + /// Constructs from Coord<3> and a batch + CUTLASS_HOST_DEVICE + GemmCoord(Coord<3, Index> const& coord): Base(make_Coord(coord[0], coord[1], coord[2])) { } + + /// Helper to construct from a K, N, M, batch variables + CUTLASS_HOST_DEVICE + GemmCoord(Index m, Index n, Index k): Base(make_Coord(m, n, k)) { } + + /// Returns the GEMM M coordinate + CUTLASS_HOST_DEVICE + Index const& m() const { return this->at(kM); } + + /// Returns reference to the GEMM M coordinate + CUTLASS_HOST_DEVICE + Index & m() { return this->at(kM); } + + /// Returns the GEMM N coordinate + CUTLASS_HOST_DEVICE + Index const& n() const { return this->at(kN); } + + /// Returns reference to the GEMM N coordinate + CUTLASS_HOST_DEVICE + Index & n() { return this->at(kN); } + + /// Returns the GEMM K coordinate + CUTLASS_HOST_DEVICE + Index const& k() const { return this->at(kK); } + + /// Returns reference to the GEMM K coordinate + CUTLASS_HOST_DEVICE + Index & k() { return this->at(kK); } + + /// Obtains a Coord<3> from GemmCoord + CUTLASS_HOST_DEVICE + Coord<3> mnk() const { + return make_Coord(m(), n(), k()); + } + + /// Obtains a Coord<3> from GemmCoord + CUTLASS_HOST_DEVICE + Coord<3> knm() const { + return make_Coord(k(), n(), m()); + } + + /// Obtains a Coord<2> from GemmCoord + CUTLASS_HOST_DEVICE + Coord<2> nm() const { + return make_Coord(n(), m()); + } + + /// Obtains a Coord<2> from GemmCoord + CUTLASS_HOST_DEVICE + Coord<2> mn() const { + return make_Coord(m(), n()); + } + + /// Obtains a Coord<2> from GemmCoord + CUTLASS_HOST_DEVICE + Coord<2> mk() const { + return make_Coord(m(), k()); + } + + /// Obtains a Coord<2> from GemmCoord + CUTLASS_HOST_DEVICE + Coord<2> km() const { + return make_Coord(k(), m()); + } + + /// Obtains a Coord<2> from GemmCoord + CUTLASS_HOST_DEVICE + Coord<2> nk() const { + return make_Coord(n(), k()); + } + + /// Obtains a Coord<2> from GemmCoord + CUTLASS_HOST_DEVICE + Coord<2> kn() const { + return make_Coord(k(), n()); + } + + // + // Coord operators + // + + /// Element-wise addition + CUTLASS_HOST_DEVICE + GemmCoord operator+(Base const& b) const { + return GemmCoord(Base::operator+(b)); + } + + /// Element-wise subtraction + CUTLASS_HOST_DEVICE + GemmCoord operator-(Base const& b) const { + return GemmCoord(Base::operator-(b)); + } + + /// Element-wise multiplication + CUTLASS_HOST_DEVICE + GemmCoord operator*(Base const& b) const { + return GemmCoord(Base::operator*(b)); + } + + /// Element-wise division + CUTLASS_HOST_DEVICE + GemmCoord operator/(Base const& b) const { + return GemmCoord(Base::operator/(b)); + } + + /// In-place addition + CUTLASS_HOST_DEVICE + GemmCoord& operator+=(Base const& b) { + Base::operator+=(b); + return *this; + } + + /// In-place subtraction + CUTLASS_HOST_DEVICE + GemmCoord& operator-=(Base const& b) { + Base::operator-=(b); + return *this; + } + + /// In-place multiplication + CUTLASS_HOST_DEVICE + GemmCoord& operator*=(Base const& b) { + Base::operator*=(b); + return *this; + } + + /// In-place division + CUTLASS_HOST_DEVICE + GemmCoord& operator/=(Base const& b) { + Base::operator/=(b); + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// BatchedGemmCoord is a structure derived from Coord<4> that specifies a location within the +/// coordinate space of a batched GEMM problem. +struct BatchedGemmCoord : public Coord<4, int> { + + /// Integer-valued index + typedef int Index; + + /// Base type is a Coord of rank=4 + typedef Coord<4, Index> Base; + + /// GEMM M dimension - rows of the output C matrix + static int const kM = 0; + + /// GEMM N dimension - columns of the output C matrix + static int const kN = 1; + + /// GEMM K dimension - inner dimension of the GEMM problem + static int const kK = 2; + + /// GEMM Batch dimension - inner dimension of the GEMM problem + static int const kBatch = 3; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + BatchedGemmCoord() { } + + /// Constructs from Coord<4> + CUTLASS_HOST_DEVICE + BatchedGemmCoord(Base const& coord): Base(coord) { } + + /// Helper to construct from a K, N, M, and batch variables + CUTLASS_HOST_DEVICE + BatchedGemmCoord(Index m, Index n, Index k, Index b): Base(make_Coord(m, n, k, b)) { } + + /// Returns the GEMM M coordinate + CUTLASS_HOST_DEVICE + Index const& m() const { return this->at(kM); } + + /// Returns reference to the GEMM M coordinate + CUTLASS_HOST_DEVICE + Index & m() { return this->at(kM); } + + /// Returns the GEMM N coordinate + CUTLASS_HOST_DEVICE + Index const& n() const { return this->at(kN); } + + /// Returns reference to the GEMM N coordinate + CUTLASS_HOST_DEVICE + Index & n() { return this->at(kN); } + + /// Returns the GEMM K coordinate + CUTLASS_HOST_DEVICE + Index const& k() const { return this->at(kK); } + + /// Returns reference to the GEMM K coordinate + CUTLASS_HOST_DEVICE + Index & k() { return this->at(kK); } + + /// Returns the GEMM batch coordinate + CUTLASS_HOST_DEVICE + Index const& batch() const { return this->at(kBatch); } + + /// Returns reference to the GEMM batch coordinate + CUTLASS_HOST_DEVICE + Index & batch() { return this->at(kBatch); } + + /// Obtains a GemmCoord from BatchedGemmCoord + CUTLASS_HOST_DEVICE + GemmCoord mnk() const { + return GemmCoord(m(), n(), k()); + } + + /// Obtains a Coord<4> from BatchedGemmCoord + CUTLASS_HOST_DEVICE + Coord<4> mnkb() const { + return make_Coord(m(), n(), k(), batch()); + } + + // + // Coord operators + // + + /// Element-wise addition + CUTLASS_HOST_DEVICE + BatchedGemmCoord operator+(Base const& b) const { + return BatchedGemmCoord(Base::operator+(b)); + } + + /// Element-wise subtraction + CUTLASS_HOST_DEVICE + BatchedGemmCoord operator-(Base const& b) const { + return BatchedGemmCoord(Base::operator-(b)); + } + + /// Element-wise multiplication + CUTLASS_HOST_DEVICE + BatchedGemmCoord operator*(Base const& b) const { + return BatchedGemmCoord(Base::operator*(b)); + } + + /// Element-wise division + CUTLASS_HOST_DEVICE + BatchedGemmCoord operator/(Base const& b) const { + return BatchedGemmCoord(Base::operator/(b)); + } + + /// In-place addition + CUTLASS_HOST_DEVICE + BatchedGemmCoord& operator+=(Base const& b) { + Base::operator+=(b); + return *this; + } + + /// In-place subtraction + CUTLASS_HOST_DEVICE + BatchedGemmCoord& operator-=(Base const& b) { + Base::operator-=(b); + return *this; + } + + /// In-place multiplication + CUTLASS_HOST_DEVICE + BatchedGemmCoord& operator*=(Base const& b) { + Base::operator*=(b); + return *this; + } + + /// In-place division + CUTLASS_HOST_DEVICE + BatchedGemmCoord& operator/=(Base const& b) { + Base::operator/=(b); + return *this; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/kernel_hardware_info.hpp b/include/cutlass/kernel_hardware_info.hpp index 586ef723a0..680036dff6 100644 --- a/include/cutlass/kernel_hardware_info.hpp +++ b/include/cutlass/kernel_hardware_info.hpp @@ -30,9 +30,11 @@ **************************************************************************************************/ #pragma once +#if !defined(__CUDACC_RTC__) #include "cuda_runtime.h" #include "cutlass/trace.h" +#endif namespace cutlass { @@ -47,6 +49,7 @@ struct KernelHardwareInfo { // Methods // +#if !defined(__CUDACC_RTC__) static int query_device_multiprocessor_count(int device_id = 0) { cudaError_t result = cudaGetDevice(&device_id); @@ -67,6 +70,7 @@ struct KernelHardwareInfo { } return multiprocessor_count; } +#endif }; } // namespace cutlass diff --git a/include/cutlass/layout/matrix.h b/include/cutlass/layout/matrix.h index c9e9c31cb5..ae84ff80a0 100644 --- a/include/cutlass/layout/matrix.h +++ b/include/cutlass/layout/matrix.h @@ -39,8 +39,6 @@ */ #pragma once -#include "cute/layout.hpp" - #include "cutlass/cutlass.h" #include "cutlass/fast_math.h" #include "cutlass/matrix_coord.h" @@ -145,15 +143,6 @@ class RowMajor { LongIndex capacity(MatrixCoord const &extent) const { return LongIndex(extent.row()) * LongIndex(stride_[0]); } - - CUTLASS_HOST_DEVICE - cute::Layout, cute::Stride > > - to_cute_layout(MatrixCoord const &extent) const { - return cute::Layout, cute::Stride > >{ - {extent[0], extent[1]}, - {stride(0), cute::Int<1>{}} - }; - } }; /// Mapping function for column-major matrices. @@ -247,15 +236,6 @@ class ColumnMajor { LongIndex capacity(MatrixCoord const &extent) const { return LongIndex(extent.column()) * LongIndex(stride_[0]); } - - CUTLASS_HOST_DEVICE - cute::Layout, cute::Stride< cute::Int<1>, int64_t> > - to_cute_layout(MatrixCoord const &extent) const { - return cute::Layout, cute::Stride, int64_t> >{ - {extent[0], extent[1]}, - {cute::Int<1>{}, stride(0)} - }; - } }; /// Mapping function for interleaved matrices. Matrix is structured @@ -558,7 +538,6 @@ struct ContiguousMatrix { /// Inverse of layout function, mapping linear offset to logical coordinate CUTLASS_HOST_DEVICE MatrixCoord inverse(LongIndex offset) const { - // TODO return MatrixCoord(0, 0); } @@ -709,7 +688,6 @@ struct AffineRankN { /// Inverse of layout function, mapping linear offset to logical coordinate CUTLASS_HOST_DEVICE TensorCoord inverse(LongIndex offset) const { - // TODO return TensorCoord(); } @@ -818,7 +796,6 @@ struct AffineRank2ColumnMajor { /// Inverse of layout function, mapping linear offset to logical coordinate CUTLASS_HOST_DEVICE MatrixCoord inverse(LongIndex offset) const { - // TODO return MatrixCoord(0, 0); } @@ -924,7 +901,6 @@ struct AffineRank2RowMajor { /// Inverse of layout function, mapping linear offset to logical coordinate CUTLASS_HOST_DEVICE MatrixCoord inverse(LongIndex offset) const { - // TODO return MatrixCoord(0, 0); } @@ -1074,7 +1050,6 @@ struct ColumnMajorBlockLinear { CUTLASS_HOST_DEVICE MatrixCoord inverse(LongIndex offset) const { - // TODO return MatrixCoord(0, 0); } @@ -1174,7 +1149,6 @@ struct RowMajorBlockLinear { /// Inverse of layout function, mapping linear offset to logical coordinate CUTLASS_HOST_DEVICE MatrixCoord inverse(LongIndex offset) const { - // TODO return MatrixCoord(0, 0); } diff --git a/include/cutlass/layout/permute.h b/include/cutlass/layout/permute.h index 73c8170710..8e1f4ceeaa 100644 --- a/include/cutlass/layout/permute.h +++ b/include/cutlass/layout/permute.h @@ -111,7 +111,7 @@ struct InversePermute { /// Helper trait to detect if permute operation is a noop template -bool constexpr is_trivial_permute = platform::is_same::value; +inline bool constexpr is_trivial_permute = platform::is_same::value; ///////////////////////////////////////////////////////////////////////////////////////////////// // @@ -383,7 +383,6 @@ class Tensor4DPermuteBMM0213RowMajorInverse : public PermuteBase { // The batch index for BMM Index BMM_batch_idx = blockIdx.z; - // TODO: semantics of the original Tensor4DPermuteBMM0213 are unclear. // The following assumes grouping [(D0)->batch, (D2)->row, (D1,D3)->col] Index l = coord.column() % D3_; Index j = coord.column() / D3_; diff --git a/include/cutlass/layout/tensor.h b/include/cutlass/layout/tensor.h index 29ac570ce5..0f10e865fd 100644 --- a/include/cutlass/layout/tensor.h +++ b/include/cutlass/layout/tensor.h @@ -60,6 +60,9 @@ namespace layout { // ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Tag used for 3-D NWC tensors for 1D conv, only used in 3.x API +class TensorNWC {}; + /// Mapping function for 4-D NHWC tensors. class TensorNHWC { public: diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 0ba84c74e7..7ee6e03c1c 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -55,6 +55,7 @@ enum class FloatRoundStyle { round_indeterminate, ///< rounding mode unknown round_toward_zero, ///< round toward zero round_to_nearest, ///< round to nearest even + round_to_nearest_satfinite, ///< round to nearest even, capping value to min and max of destination type round_toward_infinity, ///< round toward infinity round_toward_neg_infinity, ///< round toward negative infinity round_half_ulp_truncate, ///< add 0.5ulp to integer representation then round toward zero @@ -774,8 +775,7 @@ struct NumericArrayConverter { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - if( platform::is_same::value ) - { + if (platform::is_same::value) { result[i] = convert_(s[i]); } else { // conjugate result[i] = conj(convert_(s[i])); @@ -1000,8 +1000,6 @@ struct NumericArrayConverter <= Array template < int N, @@ -2079,7 +2077,7 @@ struct PackedNumericArrayConverter { } CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const{ return convert(s); } }; diff --git a/include/cutlass/numeric_types.h b/include/cutlass/numeric_types.h index 55555ec561..18715ae790 100644 --- a/include/cutlass/numeric_types.h +++ b/include/cutlass/numeric_types.h @@ -46,6 +46,9 @@ struct sizeof_bits { static int const value = int(sizeof(T) * 8); }; +template +struct sizeof_bits: sizeof_bits {}; + ///////////////////////////////////////////////////////////////////////////////////////////////// // @@ -63,6 +66,15 @@ struct sizeof_bits { ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Returns the number of bytes required to hold a specified number of bits +CUTLASS_HOST_DEVICE +constexpr int +bits_to_bytes(int bits) { + return (bits + 7) / 8; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + template struct index_sequence; @@ -89,6 +101,5 @@ using make_index_sequence = typename index_sequence_helper::type; #include "cutlass/bfloat16.h" #include "cutlass/tfloat32.h" #include "cutlass/float8.h" - ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/pipeline/sm90_pipeline.hpp b/include/cutlass/pipeline/sm90_pipeline.hpp index 807a13992b..e86d04ce5b 100644 --- a/include/cutlass/pipeline/sm90_pipeline.hpp +++ b/include/cutlass/pipeline/sm90_pipeline.hpp @@ -31,9 +31,12 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/detail/dependent_false.hpp" #include "cute/numeric/integral_constant.hpp" #include "cute/arch/cluster_sm90.hpp" #include "cutlass/arch/barrier.h" +#include "cute/util/type_traits.hpp" +#include "cute/container/array.hpp" //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -95,20 +98,18 @@ struct PipelineState { static constexpr uint32_t Stages = Stages_; -private: int index_ = 0; uint32_t phase_ = 0; - uint32_t phase_count_ = 0; + uint32_t count_ = 0; -public: CUTLASS_DEVICE - PipelineState(): index_{}, phase_{}, phase_count_{} {} + PipelineState(): index_{}, phase_{}, count_{} {} CUTLASS_DEVICE - PipelineState(int index, uint32_t phase, uint32_t phase_count) + PipelineState(int index, uint32_t phase, uint32_t count) : index_(index) , phase_(phase) - , phase_count_(phase_count) {} + , count_(count) {} CUTLASS_DEVICE int index() const { @@ -121,18 +122,18 @@ struct PipelineState { } CUTLASS_DEVICE - uint32_t phase_count() const { - return phase_count_; + uint32_t count() const { + return count_; } CUTLASS_DEVICE void operator++() { if constexpr (Stages > 0) { ++index_; + ++count_; if (index_ == Stages) { index_ = 0; phase_ ^= 1; - ++phase_count_; } } } @@ -141,7 +142,7 @@ struct PipelineState { PipelineState& operator=(const PipelineState& other) { index_ = other.index(); phase_ = other.phase(); - phase_count_ = other.phase_count(); + count_ = other.count(); return *this; } @@ -157,8 +158,8 @@ struct PipelineState { if ((num_iterations >= Stages) && (((index_ + num_iterations) / Stages) % 2) == 1) { phase_ ^= 1; } - phase_count_ += (index_ + num_iterations) / Stages; index_ = (index_ + num_iterations) % Stages; + count_ += num_iterations; } return *this; } @@ -175,8 +176,8 @@ PipelineState make_producer_start_state() { // Producer starts with an opposite phase as the buffers are initially empty constexpr int InitialProducerStage = 0; constexpr uint32_t InitialProducerPhase = 1; - constexpr uint32_t InitialProducerPhaseCount = 0; - return {InitialProducerStage, InitialProducerPhase, InitialProducerPhaseCount}; + constexpr uint32_t InitialProducerCount = 0; + return {InitialProducerStage, InitialProducerPhase, InitialProducerCount}; } /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -199,6 +200,7 @@ public : using ProducerBarrierType = FullBarrier::ValueType; using ConsumerBarrierType = EmptyBarrier::ValueType; static constexpr uint32_t Stages = Stages_; + using PipelineState = cutlass::PipelineState; struct SharedStorage { FullBarrier full_barrier_[Stages]; @@ -217,6 +219,7 @@ public : ThreadCategory role = ThreadCategory::NonParticipant; uint32_t is_leader = 0; uint32_t num_consumers = 0; + cute::tuple active_warps = {0, 0}; }; // Constructor @@ -229,21 +232,19 @@ public : int warp_idx = canonical_warp_idx(); int lane_predicate = cute::elect_one_sync(); auto cluster_shape = ClusterShape{}; - - if (warp_idx == 0 && lane_predicate == 1) { + if (warp_idx == cute::get<0>(params.active_warps) && lane_predicate == 1) { // Barrier FULL init for (int i = 0; i < Stages; ++i) { full_barrier_ptr_[i].init(1); } - // Barrier EMPTY init uint32_t const num_consumer_warpgroups_per_cluster = params_.num_consumers / NumThreadsPerWarpGroup; uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) + cute::size<1>(cluster_shape) - 1) * num_consumer_warpgroups_per_cluster; + // Barrier EMPTY init for (int i = 0; i < Stages; ++i) { empty_barrier_ptr_[i].init(multicast_consumer_arrival_count); } } - // Logic to optimally schedule Empty Arrives // Goal : To divide SYNCS Empty Arrival duty equally amongst the Warp-Group (128 threads) dim3 block_id = cute::block_id_in_cluster(); @@ -285,8 +286,10 @@ public : CUTLASS_DEVICE bool is_same_row_or_col(int dst_block_id, dim3 block_id, ClusterShape cluster_shape) { - return ((dst_block_id % cute::size<0>(cluster_shape)) == block_id.x || - (dst_block_id / cute::size<0>(cluster_shape)) == block_id.y); + return (((dst_block_id % cute::size<0>(cluster_shape)) == block_id.x) || + ( + ((dst_block_id / cute::size<0>(cluster_shape)) == block_id.y) + )); } //////////////////// @@ -310,24 +313,24 @@ public : // The finalize function will return immediately in that case. CUTLASS_DEVICE - ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { + ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { return producer_try_acquire(state.index(), state.phase(), skip_wait); } CUTLASS_DEVICE - void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { producer_acquire(state.index(), state.phase(), barrier_token); } CUTLASS_DEVICE - void producer_commit(PipelineState state, uint32_t bytes) { + void producer_commit(PipelineState state, uint32_t bytes) { producer_commit(state.index(), bytes); } // Prevents early exit of producer blocks in Cluster. // This should be called once before kernel exits. CUTLASS_DEVICE - void producer_tail(PipelineState state) { + void producer_tail(PipelineState state) { for (int count = 0; count < Stages; ++count) { producer_acquire(state); ++state; @@ -335,7 +338,7 @@ public : } CUTLASS_DEVICE - ProducerBarrierType* producer_get_barrier(PipelineState state) { + ProducerBarrierType* producer_get_barrier(PipelineState state) { return producer_get_barrier(state.index()); } @@ -343,22 +346,22 @@ public : // Consumer APIs //////////////////// CUTLASS_DEVICE - ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { + ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { return consumer_try_wait(state.index(), state.phase(), skip_wait); } CUTLASS_DEVICE - void consumer_wait(PipelineState state) { + void consumer_wait(PipelineState state) { consumer_wait(state.index(), state.phase()); } CUTLASS_DEVICE - void consumer_wait(PipelineState state, ConsumerToken barrier_token) { + void consumer_wait(PipelineState state, ConsumerToken barrier_token) { consumer_wait(state.index(), state.phase(), barrier_token); } CUTLASS_DEVICE - void consumer_release(PipelineState state) { + void consumer_release(PipelineState state) { consumer_release(state.index()); } @@ -385,7 +388,7 @@ private : } if (params_.is_leader) { - full_barrier_ptr_[stage].arrive_and_reset_bytes(params_.transaction_bytes); + full_barrier_ptr_[stage].arrive_and_expect_tx(params_.transaction_bytes); } #ifndef NDEBUG if (params_.role == ThreadCategory::Consumer || params_.role == ThreadCategory::NonParticipant) { @@ -406,7 +409,7 @@ private : #if CUTLASS_UNIT_TEST_PIPELINE if (params_.is_leader) { // STEP 1 : Commit to self - full_barrier_ptr_[stage].commit(bytes); + full_barrier_ptr_[stage].complete_transaction(bytes); // STEP 2 : Commit to other blocks in our cluster auto cluster_shape = ClusterShape{}; @@ -416,13 +419,13 @@ private : CUTLASS_PRAGMA_UNROLL for(int n = 0; n < size<1>(block_layout_in_cluster); ++n) { uint32_t dst_block_id = block_layout_in_cluster(local_block_id.x,n,Int<0>{}); - full_barrier_ptr_[stage].commit(dst_block_id, bytes, n!=local_block_id.y); + full_barrier_ptr_[stage].complete_transaction(dst_block_id, bytes, n!=local_block_id.y); } CUTLASS_PRAGMA_UNROLL for(int m = 0; m < size<0>(block_layout_in_cluster); ++m) { uint32_t dst_block_id = block_layout_in_cluster(m,local_block_id.y,Int<0>{}); - full_barrier_ptr_[stage].commit(dst_block_id, bytes, m!=local_block_id.x); + full_barrier_ptr_[stage].complete_transaction(dst_block_id, bytes, m!=local_block_id.x); } } #endif @@ -440,10 +443,7 @@ private : // Wait for producer to commit transactions (done by TMA) CUTLASS_DEVICE void consumer_wait(uint32_t stage, uint32_t phase) { - uint32_t done = full_barrier_ptr_[stage].test_wait(phase); - if (not done) { - full_barrier_ptr_[stage].wait(phase); - } + full_barrier_ptr_[stage].wait(phase); } // Wait for producer to commit transactions (done by TMA) @@ -474,16 +474,22 @@ private : /////////////////////////////////////////////////////////////////////////////////////////////////// // -// TMA store (consumer) pipeline class +// TMA store pipeline class // producer-only class, no async barriers between threads because consumer is TMA unit // /////////////////////////////////////////////////////////////////////////////////////////////////// template < - int Stages_ + int Stages_, + // The number of committed TMA store batches that can be in flight upon return of producer acquire + int UnacquiredStages_ = Stages_-1 > class PipelineTmaStore { public: static constexpr uint32_t Stages = Stages_; + static_assert(Stages_ > 0); + static_assert(UnacquiredStages_ >= 0); + static constexpr uint32_t UnacquiredStages = static_cast(UnacquiredStages_); + using PipelineState = cutlass::PipelineState; struct Params { bool always_wait = false; @@ -497,19 +503,19 @@ class PipelineTmaStore { //////////////////// // Wait for the least recently committed batch of TMA stores to complete CUTLASS_DEVICE - void producer_acquire(PipelineState state) { - producer_acquire(state.index(), state.phase_count()); + void producer_acquire(PipelineState state) { + producer_acquire(state.index(), state.count()); } // Commit the most recently issued batch of TMA stores CUTLASS_DEVICE - void producer_commit(PipelineState state) { - producer_commit(state.index(), state.phase_count()); + void producer_commit(PipelineState state) { + producer_commit(state.index(), state.count()); } // Wait for all TMA stores to complete CUTLASS_DEVICE - void producer_tail([[maybe_unused]] PipelineState state) { + void producer_tail([[maybe_unused]] PipelineState state) { tma_store_wait<0>(); } @@ -517,20 +523,71 @@ class PipelineTmaStore { Params params_; // Wait for the least recently committed batch of TMA stores to complete + // or until at most UnacquiredStages TMA store batches are in-flight (if specified) CUTLASS_DEVICE - void producer_acquire([[maybe_unused]] uint32_t stage, uint32_t phase_count) { - if (params_.always_wait || phase_count > 0) { - tma_store_wait(); + void producer_acquire([[maybe_unused]] uint32_t stage, uint32_t count) { + if (params_.always_wait || count > UnacquiredStages) { + tma_store_wait(); } } // Commit the most recently issued batch of TMA stores CUTLASS_DEVICE - void producer_commit([[maybe_unused]] uint32_t stage, [[maybe_unused]] uint32_t phase_count) { + void producer_commit([[maybe_unused]] uint32_t stage, [[maybe_unused]] uint32_t count) { + tma_store_arrive(); + } +}; + +template <> +class PipelineTmaStore< /* Stages_ = */ 0, /* UnacquiredStages = Stages_ - 1 = */ -1 > { +public: + static constexpr uint32_t Stages = 0; + static constexpr uint32_t UnacquiredStages = 0; + using PipelineState = cutlass::PipelineState; + + struct Params { + bool always_wait = false; + }; + + PipelineTmaStore() = default; + CUTLASS_DEVICE + PipelineTmaStore(Params params) : params_(params) {} + + //////////////////// + // Producer APIs + //////////////////// + + template + CUTLASS_DEVICE + void producer_acquire(PipelineState /* state */, + ThisTemplateParameterExistsOnlyForDependentFalse* /* unused */ = nullptr) { + static_assert(cutlass::detail::dependent_false, + "It is never valid to call PipelineTmaStore<0>::producer_acquire"); + } + + // Commit the most recently issued batch of TMA stores + CUTLASS_DEVICE + void producer_commit(PipelineState state) { + producer_commit(state.index(), state.count()); + } + + // Wait for all TMA stores to complete + CUTLASS_DEVICE + void producer_tail([[maybe_unused]] PipelineState state) { + tma_store_wait<0>(); + } + +private: + Params params_; + + // Commit the most recently issued batch of TMA stores + CUTLASS_DEVICE + void producer_commit([[maybe_unused]] uint32_t stage, [[maybe_unused]] uint32_t count) { tma_store_arrive(); } }; + /////////////////////////////////////////////////////////////////////////////////////////////////// // // Simple producer-consumer async Pipeline class using producer transaction barriers @@ -544,10 +601,11 @@ public : using ProducerBarrierType = FullBarrier::ValueType; using ConsumerBarrierType = EmptyBarrier::ValueType; static constexpr uint32_t Stages = Stages_; + using PipelineState = cutlass::PipelineState; struct SharedStorage { - FullBarrier full_barrier_[Stages]; - EmptyBarrier empty_barrier_[Stages]; + cute::array full_barrier_; + cute::array empty_barrier_; }; enum class ThreadCategory { @@ -563,27 +621,27 @@ public : uint32_t producer_arv_count = 1; uint32_t consumer_arv_count = 1; uint32_t dst_blockid = cute::block_rank_in_cluster(); + cute::tuple active_warps = {0, 0}; }; // Constructor CUTLASS_DEVICE PipelineTransactionAsync(SharedStorage& storage, Params const& params) : params_(params) - , full_barrier_ptr_(&storage.full_barrier_[0]) - , empty_barrier_ptr_(&storage.empty_barrier_[0]) { + , full_barrier_ptr_(storage.full_barrier_.data()) + , empty_barrier_ptr_(storage.empty_barrier_.data()) { int warp_idx = canonical_warp_idx(); int lane_predicate = cute::elect_one_sync(); // Barrier FULL, EMPTY init // Init is done only by thread 0 of the block - if (warp_idx == 0 && lane_predicate == 1) { + if (warp_idx == cute::get<0>(params.active_warps) && lane_predicate == 1) { for (int i = 0; i < Stages; ++i) { full_barrier_ptr_[i].init(params.producer_arv_count); empty_barrier_ptr_[i].init(params.consumer_arv_count); } } - cutlass::arch::fence_barrier_init(); } @@ -607,24 +665,30 @@ public : // then it is still correct to pass it into the finalize function. // The finalize function will return immediately in that case. CUTLASS_DEVICE - ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { + ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { return producer_try_acquire(state.index(), state.phase(), skip_wait); } CUTLASS_DEVICE - void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { producer_acquire(state.index(), state.phase(), barrier_token); } + // Perform an expect-tx operation on the stage's full barrier. Must be called by 1 thread + CUTLASS_DEVICE + void producer_expect_transaction(PipelineState state) { + producer_expect_transaction(state.index()); + } + CUTLASS_DEVICE - void producer_commit(PipelineState state) { + void producer_commit(PipelineState state) { producer_commit(state.index()); } // Prevents early exit of producer blocks in Cluster. // This should be called once before kernel exits. CUTLASS_DEVICE - void producer_tail(PipelineState state) { + void producer_tail(PipelineState state) { for (int count = 0; count < Stages; ++count) { producer_acquire(state); ++state; @@ -632,7 +696,7 @@ public : } CUTLASS_DEVICE - ProducerBarrierType* producer_get_barrier(PipelineState state) { + ProducerBarrierType* producer_get_barrier(PipelineState state) { return producer_get_barrier(state.index()); } @@ -640,17 +704,17 @@ public : // Consumer APIs //////////////////// CUTLASS_DEVICE - ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { + ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { return consumer_try_wait(state.index(), state.phase(), skip_wait); } CUTLASS_DEVICE - void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { + void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { consumer_wait(state.index(), state.phase(), barrier_token); } CUTLASS_DEVICE - void consumer_release(PipelineState state) { + void consumer_release(PipelineState state) { consumer_release(state.index()); } @@ -673,12 +737,17 @@ public : if (barrier_token == BarrierStatus::WaitAgain) { empty_barrier_ptr_[stage].wait(phase); } + } - full_barrier_ptr_[stage].arrive_and_reset_bytes(params_.transaction_bytes, params_.dst_blockid); + // Perform an expect-tx operation on the stage's full barrier. Must be called by 1 thread + CUTLASS_DEVICE + void producer_expect_transaction(uint32_t stage) { + full_barrier_ptr_[stage].expect_transaction(params_.transaction_bytes); } CUTLASS_DEVICE - void producer_commit([[maybe_unused]] uint32_t stage) { + void producer_commit(uint32_t stage) { + full_barrier_ptr_[stage].arrive(params_.dst_blockid); } CUTLASS_DEVICE @@ -721,6 +790,7 @@ public : using ProducerBarrierType = FullBarrier::ValueType; using ConsumerBarrierType = EmptyBarrier::ValueType; static constexpr uint32_t Stages = Stages_; + using PipelineState = cutlass::PipelineState; struct SharedStorage { FullBarrier full_barrier_[Stages]; @@ -739,6 +809,7 @@ public : uint32_t producer_arv_count = 1; uint32_t consumer_arv_count = 1; uint32_t dst_blockid = cute::block_rank_in_cluster(); + cute::tuple active_warps = {0, 0}; }; // Default assumption when only storage is passed is : @@ -760,13 +831,12 @@ public : // Barrier FULL, EMPTY init // Init is done only by thread 0 of the block - if (warp_idx == 0 && lane_predicate == 1) { + if (warp_idx == cute::get<0>(params.active_warps) && lane_predicate == 1) { for (int i = 0; i < Stages; ++i) { full_barrier_ptr_[i].init(params.producer_arv_count); empty_barrier_ptr_[i].init(params.consumer_arv_count); } } - cutlass::arch::fence_barrier_init(); } @@ -790,24 +860,24 @@ public : // then it is still correct to pass it into the finalize function. // The finalize function will return immediately in that case. CUTLASS_DEVICE - ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { + ProducerToken producer_try_acquire(PipelineState state, uint32_t skip_wait = false) { return producer_try_acquire(state.index(), state.phase(), skip_wait); } CUTLASS_DEVICE - void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { producer_acquire(state.index(), state.phase(), barrier_token); } CUTLASS_DEVICE - void producer_commit(PipelineState state) { + void producer_commit(PipelineState state) { producer_commit(state.index()); } // Prevents early exit of producer blocks in Cluster. // This should be called once before kernel exits. CUTLASS_DEVICE - void producer_tail(PipelineState state) { + void producer_tail(PipelineState state) { for (int count = 0; count < Stages; ++count) { producer_acquire(state); ++state; @@ -815,7 +885,7 @@ public : } CUTLASS_DEVICE - ProducerBarrierType* producer_get_barrier(PipelineState state) { + ProducerBarrierType* producer_get_barrier(PipelineState state) { return producer_get_barrier(state.index()); } @@ -823,17 +893,17 @@ public : // Consumer APIs //////////////////// CUTLASS_DEVICE - ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { + ConsumerToken consumer_try_wait(PipelineState state, uint32_t skip_wait = false) { return consumer_try_wait(state.index(), state.phase(), skip_wait); } CUTLASS_DEVICE - void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { + void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { consumer_wait(state.index(), state.phase(), barrier_token); } CUTLASS_DEVICE - void consumer_release(PipelineState state) { + void consumer_release(PipelineState state) { consumer_release(state.index()); } @@ -919,6 +989,7 @@ public : struct Params { uint32_t group_id; uint32_t group_size; + cute::tuple active_warps = {0, 0}; }; private : @@ -944,20 +1015,18 @@ private : barrier_ptr_(&storage.barrier_[0][0]), // Group 0 - starts with an opposite phase stage_({0, params.group_id == 0, 0}) { - int warp_idx = canonical_warp_idx(); int lane_predicate = cute::elect_one_sync(); // Barrier FULL, EMPTY init // Init is done only by the one elected thread of the block - if (warp_idx == 0 && lane_predicate == 1) { + if (warp_idx == cute::get<0>(params.active_warps) && lane_predicate == 1) { for (int d = 0; d < Depth; ++d) { for (int l = 0; l < Length; ++l) { barrier_ptr_[d * Length + l].init(params.group_size); } } } - cutlass::arch::fence_barrier_init(); } diff --git a/include/cutlass/platform/platform.h b/include/cutlass/platform/platform.h index f0582f043a..f2cd5627cd 100644 --- a/include/cutlass/platform/platform.h +++ b/include/cutlass/platform/platform.h @@ -90,9 +90,6 @@ * - \p alignment_of * - \p aligned_storage * - * (4) Functions and types that are STL-like (but aren't in the STL): - * - \p TODO: min and max functors? - * * The idea is that, as we drop support for older compilers, we can simply #define * the \p __NV_STD_XYZ macros and \p platform namespace to alias their C++ * counterparts (or trivially find-and-replace their occurrences in code text). @@ -103,7 +100,11 @@ //----------------------------------------------------------------------------- #if defined(__CUDACC_RTC__) +#include +#include +#include #include +#include #else #include #endif @@ -135,6 +136,24 @@ /****************************************************************************** * Macros ******************************************************************************/ +/// std +#if !defined(CUTLASS_STL_NAMESPACE) +#if defined(__CUDACC_RTC__) +#define CUTLASS_STL_NAMESPACE cuda::std +#else +#define CUTLASS_STL_NAMESPACE std +#endif +#endif + +/// builtin_unreachable +#if !defined(CUTLASS_GCC_UNREACHABLE) +# if defined(__clang__) || defined(__GNUC__) +# define CUTLASS_GCC_UNREACHABLE __builtin_unreachable() +# else +# define CUTLASS_GCC_UNREACHABLE +# endif +#endif + //----------------------------------------------------------------------------- // Keywords //----------------------------------------------------------------------------- @@ -366,6 +385,9 @@ using std::conditional; #endif +/// std::conditional_t +using CUTLASS_STL_NAMESPACE::conditional_t; + //----------------------------------------------------------------------------- // Const/volatility specifiers //----------------------------------------------------------------------------- @@ -410,6 +432,23 @@ using std::remove_cv; #endif +/// std::remove_cv_t +using CUTLASS_STL_NAMESPACE::remove_cv_t; +/// std::remove_reference_t +using CUTLASS_STL_NAMESPACE::remove_reference_t; + +// C++20 +// using std::remove_cvref; +template +struct remove_cvref { + using type = remove_cv_t>; +}; + +// C++20 +// using std::remove_cvref_t; +template +using remove_cvref_t = typename remove_cvref::type; + //----------------------------------------------------------------------------- // Type relationships //----------------------------------------------------------------------------- @@ -574,6 +613,11 @@ using std::is_trivially_copyable; #endif +/// std::is_unsigned_v +using CUTLASS_STL_NAMESPACE::is_integral_v; +/// std::is_unsigned_v +using CUTLASS_STL_NAMESPACE::is_unsigned_v; + //----------------------------------------------------------------------------- // bit_cast //----------------------------------------------------------------------------- @@ -889,5 +933,19 @@ struct numeric_limits { }; #endif +/// std::float_round_style +using CUTLASS_STL_NAMESPACE::float_round_style; +using CUTLASS_STL_NAMESPACE::round_indeterminate; +using CUTLASS_STL_NAMESPACE::round_toward_zero; +using CUTLASS_STL_NAMESPACE::round_to_nearest; +using CUTLASS_STL_NAMESPACE::round_toward_infinity; +using CUTLASS_STL_NAMESPACE::round_toward_neg_infinity; + +/// std::float_denorm_style +using CUTLASS_STL_NAMESPACE::float_denorm_style; +using CUTLASS_STL_NAMESPACE::denorm_indeterminate; +using CUTLASS_STL_NAMESPACE::denorm_absent; +using CUTLASS_STL_NAMESPACE::denorm_present; + } // namespace platform } // namespace cutlass diff --git a/include/cutlass/quaternion.h b/include/cutlass/quaternion.h index 1015be4bf0..3e4e88a877 100644 --- a/include/cutlass/quaternion.h +++ b/include/cutlass/quaternion.h @@ -610,7 +610,6 @@ Quaternion operator/(Element s, Quaternion const &q) { template CUTLASS_HOST_DEVICE bool operator<(Quaternion const &lhs, Quaternion const &rhs) { - //TODO return true; } diff --git a/include/cutlass/subbyte_reference.h b/include/cutlass/subbyte_reference.h index 58c460a000..ba6fe3aa75 100644 --- a/include/cutlass/subbyte_reference.h +++ b/include/cutlass/subbyte_reference.h @@ -34,6 +34,7 @@ #pragma once #include "cutlass/numeric_types.h" +#include "cutlass/fast_math.h" namespace cutlass { @@ -60,8 +61,9 @@ namespace cutlass { /// template < typename Element_, /// CUTLASS numeric element type. - typename Storage_ = uint8_t /// Underlying storage type. Must be able to hold an integer + typename Storage_ = uint8_t, /// Underlying storage type. Must be able to hold an integer /// number of objects of type Element. + class = void > class ConstSubbyteReference { public: @@ -306,6 +308,8 @@ template < #else uint8_t #endif + , + class = void > class SubbyteReference { public: @@ -602,6 +606,673 @@ class SubbyteReference { ///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename Element_, /// CUTLASS numeric element type. + typename Storage_ /// Underlying basic storage type. +> +class SubbyteReference::value % sizeof_bits::value != 0>::type> { +public: + + using Element = Element_; + ///! Note: Storage unit could not be divisibale by Element, + /// Type element may be stored across 2 storage units, so need a storage vector to hold integer + /// number of objects of type Element. + using StorageUnit = Storage_; + static int const kBitsStoredVec = cutlass::lcm(sizeof_bits::value, sizeof_bits::value); + static int const kNumStorageUnitPerStoredVec = kBitsStoredVec / sizeof_bits::value; + + using StorageVec = StorageUnit[kNumStorageUnitPerStoredVec]; + using StorageVecPointer = StorageVec *; + + using CudaAtomicType = typename platform::conditional< + sizeof_bits::value == 16, + uint32_t, + uint64_t + >::type; + + static_assert(sizeof_bits::value <= sizeof_bits::value, + "Size of Element must not be greater than StorageVec."); + + static_assert(!(sizeof_bits::value % sizeof_bits::value), + "StorageVec must be divisible by Element"); + +private: + + ///! Number of elements per storage vector + int const kElementsPerVector = sizeof_bits::value / sizeof_bits::value; + + ///! Bit mask for storage unit. + StorageUnit const kMask = (StorageUnit(1) << sizeof_bits::value) - StorageUnit(1); + + /// Pointer to array containing element + StorageVecPointer ptr_; + + /// Offset (in units of elements) from pointer. + /// + /// Invariant: must always be in range [0, kElementsPerVector) + int offset_; + + /// Element may be stored across 2 storage unit. + /// Low storage unit index in StorageVec + /// High storage unit index in StorageVec + int low_storage_unit_idx_; + int high_storage_unit_idx_; + + /// Full Mask to extract the entire element + uint64_t full_element_mask_; + + /// Mask to extract the Element from Low storage unit and High storage unit. + StorageUnit low_storage_mask_; + StorageUnit high_storage_mask_; + + /// Start bit index inside the storage unit. + int start_bit_idx_; + +private: + + CUTLASS_HOST_DEVICE + void update_element_status() { + int num_bits = offset_ * sizeof_bits::value; + + start_bit_idx_ = num_bits % sizeof_bits::value; + + low_storage_unit_idx_ = num_bits / sizeof_bits::value; + high_storage_unit_idx_ = sizeof_bits::value - (start_bit_idx_) < sizeof_bits::value + ? low_storage_unit_idx_ + 1 : low_storage_unit_idx_; + + full_element_mask_ = uint64_t(kMask) << start_bit_idx_; + low_storage_mask_ = StorageUnit(full_element_mask_ & ~StorageUnit(0)); + high_storage_mask_ = StorageUnit((full_element_mask_ >> sizeof_bits::value) & ~StorageUnit(0)); + } + +public: + + CUTLASS_HOST_DEVICE + SubbyteReference(): ptr_(nullptr), offset_(0) { } + + /// Constructor + CUTLASS_HOST_DEVICE + SubbyteReference( + Element *ptr, /// pointer to memory + int64_t offset /// logical offset in units of Element + ): + ptr_(reinterpret_cast(ptr)), + offset_(0) { + int64_t offset_in_vectors = offset / kElementsPerVector; + int64_t offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = int(offset_in_elements); + + update_element_status(); + } + + /// Constructor + CUTLASS_HOST_DEVICE + SubbyteReference( + Element *ptr = nullptr + ): SubbyteReference(ptr, 0) { } + + /// Gets StorageVec pointer + CUTLASS_HOST_DEVICE + StorageVecPointer storage_pointer() const { + return ptr_; + } + + /// Gets StorageVec pointer + CUTLASS_HOST_DEVICE + Element * operator&() const { + return reinterpret_cast(ptr_); + } + + /// Gets element offset within StorageVec vector + CUTLASS_HOST_DEVICE + int element_offset() const { + return offset_; + } + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + Element get() const { + StorageUnit low_bits = (*ptr_)[low_storage_unit_idx_] & low_storage_mask_; + StorageUnit high_bits = low_storage_unit_idx_ != high_storage_unit_idx_ ? (*ptr_)[high_storage_unit_idx_] & high_storage_mask_ : 0; + + uint64_t full_item = ((uint64_t)high_bits << sizeof_bits::value) | low_bits; + uint8_t result = uint8_t(full_item >> start_bit_idx_); + + return reinterpret_cast(result); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference & set(Element const &x) { + + uint64_t item = static_cast((reinterpret_cast(x) & kMask)) << start_bit_idx_; + + StorageUnit low_new_bits = StorageUnit(item & ~StorageUnit(0)); + StorageUnit high_new_bits = StorageUnit(item >> sizeof_bits::value); + + StorageUnit const kLowUpdateMask = StorageUnit((~full_element_mask_) & (~StorageUnit(0))); + StorageUnit const kHighUpdateMask = StorageUnit(((~full_element_mask_) >> sizeof_bits::value) & (~StorageUnit(0))); + +#if defined(__CUDA_ARCH__) + // + // Homebrew read-modify-write + // + if(high_storage_unit_idx_ != low_storage_unit_idx_){ + /// Only need update 2 storage unit at once. + CudaAtomicType original, updated; + do { + StorageUnit original_low_bits = ((*ptr_)[low_storage_unit_idx_]); + StorageUnit original_high_bits = ((*ptr_)[high_storage_unit_idx_]); + + original = (CudaAtomicType(original_low_bits) << sizeof_bits::value) | original_low_bits; + + + StorageUnit update_low_bits = (original_low_bits & kLowUpdateMask) | low_new_bits; + StorageUnit update_high_bits = (original_high_bits & kHighUpdateMask) | high_new_bits; + + updated = (CudaAtomicType(update_high_bits) << sizeof_bits::value) | update_low_bits; + + original = atomicCAS(reinterpret_cast(ptr_), original, updated); + + } while (updated != original); + } + else { + /// Only need update 1 storage unit. + StorageUnit original, updated; + do { + original = ((*ptr_)[low_storage_unit_idx_]); + + updated = (original & kLowUpdateMask) | low_new_bits; + + original = atomicCAS(reinterpret_cast(ptr_), original, updated); + + } while (updated != original); + } +#else + + StorageUnit update_low_bits = ((*ptr_)[low_storage_unit_idx_] & kLowUpdateMask) | low_new_bits; + StorageUnit update_high_bits = ((*ptr_)[high_storage_unit_idx_] & kHighUpdateMask) | high_new_bits; + + (*ptr_)[low_storage_unit_idx_] = update_low_bits; + + if(low_storage_unit_idx_ != high_storage_unit_idx_) + (*ptr_)[high_storage_unit_idx_] = update_high_bits; +#endif + + return *this; + } + + //// + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + operator Element() const { + return get(); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference &operator=(Element const & x) { + return set(x); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference &operator=(SubbyteReference const & x) { + return set(x.get()); + } + + /// Stores an element to memory + CUTLASS_HOST_DEVICE + SubbyteReference &operator=( + ConstSubbyteReference const &x) { + return set(x.get()); + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator+=(int offset) { + + offset += offset_; + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + update_element_status(); + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator+=(long long offset) { + + offset += offset_; + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + update_element_status(); + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator-=(int offset) { + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + update_element_status(); + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + SubbyteReference &operator-=(long long offset) { + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + update_element_status(); + return *this; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator+(int offset) const { + + SubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator+(long long offset) const { + + SubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator-(int offset) const { + + SubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + SubbyteReference operator-=(long long offset) const { + + SubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Computes the difference in elements between references + CUTLASS_HOST_DEVICE + ptrdiff_t operator-(SubbyteReference ref) const { + return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_); + } + + /// Explicit cast to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(get()); + } + + /// Explicit cast to signed 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator int64_t() const { + return int64_t(get()); + } + + /// Explicit cast to unsigned 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator uint64_t() const { + return uint64_t(get()); + } + + /// Explicit cast to float + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + + /// Explicit cast to double + CUTLASS_HOST_DEVICE + explicit operator double() const { + return double(get()); + } +}; + +template < + typename Element_, /// CUTLASS numeric element type. + typename Storage_ /// Underlying storage type. Must be able to hold an integer +> +class ConstSubbyteReference::value % sizeof_bits::value != 0>::type> { +public: + + using Element = Element_; + ///! Note: Storage unit could not be divisibale by Element, + /// Type element may be stored across 2 storage units, so need a storage vector to hold integer + /// number of objects of type Element. + using StorageUnit = Storage_; + static int const kBitsStoredVec = cutlass::lcm(sizeof_bits::value, sizeof_bits::value); + static int const kNumStorageUnitPerStoredVec = kBitsStoredVec / sizeof_bits::value; + + using StorageVec = StorageUnit[kNumStorageUnitPerStoredVec]; + using StorageVecPointer = StorageVec const *; + + using CudaAtomicType = typename platform::conditional< + sizeof_bits::value == 16, + uint32_t, + uint64_t + >::type; + + static_assert(sizeof_bits::value <= sizeof_bits::value, + "Size of Element must not be greater than StorageVec."); + + static_assert(!(sizeof_bits::value % sizeof_bits::value), + "StorageVec must be divisible by Element"); + +private: + + ///! Number of elements per storage vector + int const kElementsPerVector = sizeof_bits::value / sizeof_bits::value; + + ///! Bit mask for storage unit. + StorageUnit const kMask = (StorageUnit(1) << sizeof_bits::value) - StorageUnit(1); + + /// Pointer to array containing element + StorageVecPointer ptr_; + + /// Offset (in units of elements) from pointer. + /// + /// Invariant: must always be in range [0, kElementsPerVector) + int offset_; + + /// Element may be stored across 2 storage unit. + /// Low storage unit index in StorageVec + /// High storage unit index in StorageVec + int low_storage_unit_idx_; + int high_storage_unit_idx_; + + /// Full Mask to extract the entire element + uint64_t full_element_mask_; + + /// Mask to extract the Element from Low storage unit and High storage unit. + StorageUnit low_storage_mask_; + StorageUnit high_storage_mask_; + + /// Start bit index inside the storage unit. + int start_bit_idx_; + +private: + + CUTLASS_HOST_DEVICE + void update_element_status() { + int num_bits = offset_ * sizeof_bits::value; + + start_bit_idx_ = num_bits % sizeof_bits::value; + + low_storage_unit_idx_ = num_bits / sizeof_bits::value; + high_storage_unit_idx_ = sizeof_bits::value - (start_bit_idx_) < sizeof_bits::value + ? low_storage_unit_idx_ + 1 : low_storage_unit_idx_; + + full_element_mask_ = uint64_t(kMask) << start_bit_idx_; + low_storage_mask_ = StorageUnit(full_element_mask_ & ~StorageUnit(0)); + high_storage_mask_ = StorageUnit((full_element_mask_ >> sizeof_bits::value) & ~StorageUnit(0)); + } + +public: + + CUTLASS_HOST_DEVICE + ConstSubbyteReference(): ptr_(nullptr), offset_(0) { } + + /// Constructor + CUTLASS_HOST_DEVICE + ConstSubbyteReference( + Element const *ptr, /// pointer to memory + int64_t offset /// logical offset in units of Element + ): + ptr_(reinterpret_cast(ptr)), + offset_(0) { + + int64_t offset_in_vectors = offset / kElementsPerVector; + int64_t offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = int(offset_in_elements); + + update_element_status(); + } + + /// Constructor + CUTLASS_HOST_DEVICE + ConstSubbyteReference( + Element *ptr = nullptr + ): ConstSubbyteReference(ptr, 0) { } + + /// Gets storage pointer + CUTLASS_HOST_DEVICE + StorageVecPointer storage_pointer() const { + return ptr_; + } + + /// Gets element offset within storage vector + CUTLASS_HOST_DEVICE + int element_offset() const { + return offset_; + } + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + Element get() const { + StorageUnit low_bits = (*ptr_)[low_storage_unit_idx_] & low_storage_mask_; + StorageUnit high_bits = low_storage_unit_idx_ != high_storage_unit_idx_ ? (*ptr_)[high_storage_unit_idx_] & high_storage_mask_ : 0; + + uint64_t full_item = ((uint64_t)high_bits << sizeof_bits::value) | low_bits; + uint8_t result = uint8_t(full_item >> start_bit_idx_); + + return reinterpret_cast(result); + } + + /// Unpacks an element from memory + CUTLASS_HOST_DEVICE + operator Element() const { + return get(); + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator+=(int offset) { + + offset += offset_; + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + update_element_status(); + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator+=(long long offset) { + + offset += offset_; + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ += offset_in_vectors; + offset_ = offset_in_elements; + + update_element_status(); + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator-=(int offset) { + + int offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = offset % kElementsPerVector; + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + update_element_status(); + + return *this; + } + + /// Adds an offset in units of elements to the reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference &operator-=(long long offset) { + + long long offset_in_vectors = offset / kElementsPerVector; + int offset_in_elements = int(offset % kElementsPerVector); + + ptr_ -= offset_in_vectors; + offset_ -= offset_in_elements; + + if (offset_ < 0) { + offset_ += kElementsPerVector; + --ptr_; + } + + update_element_status(); + + return *this; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator+(int offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator+(long long offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref += offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator-(int offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Returns a reference to an element with a given offset from the current reference + CUTLASS_HOST_DEVICE + ConstSubbyteReference operator-=(long long offset) const { + + ConstSubbyteReference ref(ptr_, offset_); + ref -= offset; + + return ref; + } + + /// Computes the difference in elements between references + CUTLASS_HOST_DEVICE + ptrdiff_t operator-(ConstSubbyteReference ref) const { + return (ptr_ - ref.ptr_) * kElementsPerVector + (offset_ - ref.offset_); + } + + /// Explicit cast to int + CUTLASS_HOST_DEVICE + explicit operator int() const { + return int(get()); + } + + /// Explicit cast to signed 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator int64_t() const { + return int64_t(get()); + } + + /// Explicit cast to unsigned 64-bit integer + CUTLASS_HOST_DEVICE + explicit operator uint64_t() const { + return uint64_t(get()); + } + + /// Explicit cast to float + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + + /// Explicit cast to double + CUTLASS_HOST_DEVICE + explicit operator double() const { + return double(get()); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + template ::value < 8)> struct ReferenceFactory; diff --git a/include/cutlass/thread/matrix.h b/include/cutlass/thread/matrix.h index bc78cf8512..4793fcbf27 100644 --- a/include/cutlass/thread/matrix.h +++ b/include/cutlass/thread/matrix.h @@ -134,7 +134,6 @@ class Matrix : public Array { /// Ctor CUTLASS_HOST_DEVICE Matrix(Diagonal const &diag) { - // Todo - construct from diagonal } /// Returns a TensorRef pointing to the first element of the tensor. diff --git a/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp b/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp index 63de672635..27039c64f5 100644 --- a/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp +++ b/include/cutlass/transform/collective/sm90_wgmma_transpose.hpp @@ -92,60 +92,470 @@ use_universal_transposition() { } } -/// Transpose B operand in SMEM -template < - class TensorSmemB, - class TensorTransposedSmemB, - class PipelineState, - class TiledMma, - class SmemLayoutB, - class SmemLayoutAtomB, - class ElementB> -CUTLASS_DEVICE void -transpose_b_operand ( - TensorSmemB const& sB, - TensorTransposedSmemB const& gmma_sB, - PipelineState const& smem_pipe_read, - int warp_idx, int warp_group_thread_idx, - TiledMma, SmemLayoutB, SmemLayoutAtomB, ElementB) -{ - ////////////////////////////////////////////////////////////////////////////////////////////////////////////// - /// Important terms: - /// WarpgroupTileSize : The warp_group_tile size (WarpgroupTileSize x WarpgroupTileSize) a warp group would transpose - /// WarpTileSize : The warp_tile size (WarpTile x WarpTile) a warp would transpose - /// Step : The number of steps a warp group takes to complete the entire warp_group_tile transposition. - /// WarpTileNCoordLUT : The look up table to store the n-dim coords used by the warps - /// WarpTileKCoordLUT : The look up table to store the k-dim coords used by the warps - ////////////////////////////////////////////////////////////////////////////////////////////////////////////// - static_assert(size(TiledMma{}) == NumThreadsPerWarpGroup, "Wrong thread number for TransposeB"); - constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K. - constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; - - constexpr int BytesPerSmemSwizzleUnit = 16; - constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB); - constexpr int WarpThreadShapeK = NumThreadsPerWarp / WarpThreadShapeN; - - ////////////////////////////////////////////////////////////////////////////////////////////////////////////// - /// Optimized transposition, less regs per thread than universal approach, need warp sync between load and store - /// TF32/FP32 would use the 2-steps approach. Fp8/Int8 would use 8-steps approach. - ////////////////////////////////////////////////////////////////////////////////////////////////////////////// - if constexpr (!detail::use_universal_transposition()) { - constexpr int Steps = sizeof(ElementB) == 1 ? 8 : 2; - constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1); - - constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile; - static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invaild warp thread shape." ); +template< + class TiledMma_, + class SmemLayoutB_, + class SmemLayoutAtomB_, + class ElementB_> +class NoTranspositionOperandB { +public: + using TiledMma = TiledMma_; + using SmemLayoutB = SmemLayoutB_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using ElementB = ElementB_; + + constexpr CUTLASS_HOST_DEVICE + NoTranspositionOperandB( + int, + int, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, + ElementB) { } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void operator()( + TensorSmemB const&, + TensorTransposedSmemB const&, + int, int) { } + + CUTLASS_DEVICE void synchronize(int) { } + + CUTLASS_DEVICE void synchronize() { } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void transpose( + TensorSmemB const&, + TensorTransposedSmemB const&, + int) { } +}; + +template< + class TiledMma_, + class SmemLayoutB_, + class SmemLayoutAtomB_, + class ElementB_> +class UniversalTranspositionOperandB { +public: + using TiledMma = TiledMma_; + using SmemLayoutB = SmemLayoutB_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using ElementB = ElementB_; + + constexpr CUTLASS_HOST_DEVICE + UniversalTranspositionOperandB( + int warp_idx_, + int warp_group_thread_idx_, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, + ElementB) + : warp_idx(warp_idx_) + , warp_group_thread_idx(warp_group_thread_idx_) { } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void operator()( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + int read_stage, int current_step) { + if (current_step > 0) { + return; + } + + constexpr int NumMathWarpGroup = size(TiledMma{}) / NumThreadsPerWarpGroup; + static_assert(NumMathWarpGroup == 1 || + (!detail::use_universal_transposition() && NumMathWarpGroup == 2), + "Wrong math warp group number for TransposeB"); + constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K. + + constexpr int BytesPerSmemSwizzleUnit = 16; + constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB); + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// Universal transposition, need warp_group sync between load and store. + /// The number of reg used depends on the input elementB. + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /* + In one copy step, a warp group would load WarpgroupTileSize * WarpgroupTileSize tile then store to transposed location. + In warp_group_tile, each warp holds Four WarpTileSize x WarpTileSize elements: + K + ------------ + | W0 W1 W2 W3 --- + | W0 W1 W2 W3 | + | W0 W1 W2 W3 | --> Copy Step 0 + | W0 W1 W2 W3 --- + .... + | W0 W1 W2 W3 --- + | W0 W1 W2 W3 | + | W0 W1 W2 W3 | --> Copy Step n + | W0 W1 W2 W3 --- + */ + static_assert((NumThreadsPerWarpGroup % WarpThreadShapeN == 0), "Unsupported warp thread layout."); + constexpr auto WarpgroupThreadLayout = make_layout(make_shape(Int{}, Int{})); + + // Get copy tile and partition to each thread + auto sB_tiled_copy = make_tiled_copy( + Copy_Atom{}, + WarpgroupThreadLayout, // thr_layout + Layout<_1>{} // val_layout + ); + static_assert(size(sB_tiled_copy) == size(TiledMma{}), "Wrong thread number in TiledCopy."); + + auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx); + Tensor tCsB = sB_thr_copy.partition_S( sB(_,_,read_stage)); // (CPY, CPY_N, CPY_K) + Tensor tCsB_transposed = sB_thr_copy.partition_D(gmma_sB(_,_,read_stage)); // (CPY, CPY_N, CPY_K) + + // Divide partitioned tile to limit register usage + constexpr int CopySteps = size<0>(SmemLayoutB{}) / WarpgroupTileSize; + constexpr auto CopyTileShape = make_shape(size<0>(tCsB), Int< size<1>(tCsB) / CopySteps >{}, size<2>(tCsB)); + static_assert(size<1>(tCsB) % CopySteps == 0, "CopySteps must evenly divide rank 1 size of partitioned SMEM."); + + Tensor tCsB_copy_tile = zipped_divide(tCsB, CopyTileShape); + Tensor tCsB_copy_tile_transposed = zipped_divide(tCsB_transposed, CopyTileShape); + auto transpose_fragment = make_fragment_like(tCsB_copy_tile(_,_0{})); + + CUTLASS_PRAGMA_NO_UNROLL + for (int step = 0; step < CopySteps; ++step) { + copy(sB_tiled_copy, tCsB_copy_tile(_,step), transpose_fragment); + + // Make sure all elements are read before being overwritten + __syncthreads(); + + copy(sB_tiled_copy, transpose_fragment, tCsB_copy_tile_transposed(_,step)); + } + } + + CUTLASS_DEVICE void synchronize(int step) { + if (step == 0) { + // SMEM fence to make sure B is transposed before math + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 1); + } + } + + CUTLASS_DEVICE void synchronize() { + // SMEM fence to make sure B is transposed before math + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 1); + } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void transpose( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + int read_stage) { + + this->operator()(sB, gmma_sB, read_stage, 0); + synchronize(); + + } + +private: + const int warp_idx; + const int warp_group_thread_idx; +}; + +template< + class TiledMma_, + class SmemLayoutB_, + class SmemLayoutAtomB_, + class ElementB_> +class AsyncTranspositionOperandB { +public: + + using TiledMma = TiledMma_; + using SmemLayoutB = SmemLayoutB_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using ElementB = ElementB_; + + static constexpr int Steps = 2; + static constexpr int NumMathWarpGroup = size(TiledMma{}) / NumThreadsPerWarpGroup; + static constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup; + static_assert(NumMathWarpGroup <= 2, + "Wrong math warp group number for TransposeB"); + static constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K. + static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; + + static constexpr int BytesPerSmemSwizzleUnit = 16; + static constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB); + static constexpr int WarpThreadShapeK = NumThreadsPerWarp / WarpThreadShapeN; + static constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1); + + static constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile; + static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invaild warp thread shape." ); + static constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step. + static constexpr int64_t WarpTileNCoordLUT = 06723763275316420; + static constexpr int64_t WarpTileKCoordLUT = 05410541064206420; + static constexpr int NumStepsEncoded = 4; // Only encoding first 4 steps into LUT. + static constexpr int MaskPerStep = 07; // Each step is encoded into 3bits, + static constexpr int NumBitsPerStep = 3; + static constexpr int MaskPerWarp = 07777; // Each warp has 4 steps(12 bits) + static constexpr int NumBitsPerWarp = 12; + // Number of warp_group_tiles + static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0, + "Copy size must evenly divide SMEM tile."); + static constexpr int WarpgroupTileNum = size<0>(SmemLayoutB{}) / WarpgroupTileSize; + + static_assert(size<2>(typename TiledMma::AtomShape_MNK{}) <= WarpThreadShapeK, + "Need to be able to transpose first k-block in the first step"); + + constexpr CUTLASS_HOST_DEVICE + AsyncTranspositionOperandB( + int warp_idx_, + int warp_group_thread_idx_, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, + ElementB) + : warp_idx(warp_idx_) + , warp_group_thread_idx(warp_group_thread_idx_) + , warp_idx_in_warp_group(warp_idx_ % NumWarpsPerWarpGroup) + , current_warp_tile_n_coord_LUT((WarpTileNCoordLUT >> ((warp_idx_ + % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) + , current_warp_tile_k_coord_LUT((WarpTileKCoordLUT >> ((warp_idx_ + % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) { } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void operator()( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + int read_stage, int current_step) + { + if (current_step >= StepsPerWarpGroup) { + return; + } + + static constexpr auto WarpThreadLayout = make_layout(make_shape(Int{}, Int{})); + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// A warp group uses 2 steps to transpose the whole WarpgroupTileSize x WarpgroupTileSize. + /// In each step, one warp would hold two warp_tiles. + /// Step 0: Step 1: + /// W0 W1 W2 W3 -- -- -- -- + /// W1 W0 -- -- -- -- W3 W2 + /// W2 -- -- -- -- W3 W0 W1 + /// W3 -- -- -- -- W2 W1 W0 + /// + ///////////////////////////////////////////////////////////////////////////////////////////////////////////// + /// + /// Fully static coord LUT to avoid extra register use. + /// [warp_id][step][warp_tile][n / k] + /// Step 0 Step 1 Step 2 Step 3 Step 4 Step 5 Step 6 Step 7 + /// {{{0,0}, {1,1}}, {{2,2}, {3,3}}, {{4,4}, {5,5}}, {{6,6}, {7,7}}, {{4,0}, {0,4}}, {{4,1}, {1,4}}, {{4,2}, {2,4}}, {{4,3}, {3,4}}}, // W0 + /// {{{1,0}, {0,1}}, {{3,2}, {2,3}}, {{5,4}, {4,5}}, {{7,6}, {6,7}}, {{5,0}, {0,5}}, {{5,1}, {1,5}}, {{5,2}, {2,5}}, {{5,3}, {3,5}}}, // W1 + /// {{{2,0}, {0,2}}, {{3,1}, {1,3}}, {{6,4}, {4,6}}, {{7,5}, {5,7}}, {{6,0}, {0,6}}, {{6,1}, {1,6}}, {{6,2}, {2,6}}, {{6,3}, {3,6}}}, // W2 + /// {{{3,0}, {0,3}}, {{2,1}, {1,2}}, {{7,4}, {4,7}}, {{6,5}, {5,6}}, {{7,0}, {0,7}}, {{7,1}, {1,7}}, {{7,2}, {2,7}}, {{7,3}, {3,7}}}, // W3 + /// + /// Encoding the coord of warp tile0 into two int64_t values. + /// Only encoding Step 0 ~ Step 4, since Step 5 ~ Step 7 have a straightforward pattern. + /// Only encoding warp tile0, since the coords of warp tile1 could be easily deduced from warp tile0. + /// The 2-step transposition and the 8-step transposition share the same encoding. + /// + ////////////////////////////////////////////////////////////////////////////////////////////////////////////// + + // Divide entire SMEM to multiple warp_tiles + constexpr auto WarpTileShape = make_shape(Int(), Int()); + Tensor s_tile = zipped_divide( sB(_,_,read_stage), WarpTileShape); + Tensor s_tile_transposed = zipped_divide(gmma_sB(_,_,read_stage), WarpTileShape); + + // Get copy tile + auto sB_tiled_copy = make_tiled_copy( + Copy_Atom{}, + WarpThreadLayout, // thr_layout + Layout<_1>{} // val_layout + ); + + static_assert(size(sB_tiled_copy) * NumWarpsPerWarpGroup == size(TiledMma{}) / NumMathWarpGroup, "Wrong thread number in TiledCopy."); + auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx % NumThreadsPerWarp); // slice based on lane_idx + + // Construct fragments for transposition + Tensor tmp_tCsB = sB_thr_copy.partition_S(flatten(s_tile(_, make_coord(_0{}, _0{})))); + decltype(make_fragment_like(tmp_tCsB)) transpose_fragments[TilesPerWarp] = { + make_fragment_like(tmp_tCsB), + make_fragment_like(tmp_tCsB) + }; + + int step = current_step * NumMathWarpGroup; + if constexpr (NumMathWarpGroup == 2) { + // For 2 math warpgroup, warp idx4~7 is 1st warp group and 8~9 is 2nd, so decide if 2nd warpgroup need warp idx divide 8. + step += warp_idx / (NumWarpsPerWarpGroup * 2); + } + + int tmp_warp_tile_n_coord_LUT = current_warp_tile_n_coord_LUT >> (NumBitsPerStep * current_step); + int tmp_warp_tile_k_coord_LUT = current_warp_tile_k_coord_LUT >> (NumBitsPerStep * current_step); + + if constexpr (NumMathWarpGroup == 2) { + tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2)); + tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2)); + } + + // decoding the warp tile coord. + int warp_tile0_n, warp_tile0_k; + if constexpr (StepsPerWarpGroup <= NumStepsEncoded) { + warp_tile0_n = tmp_warp_tile_n_coord_LUT & MaskPerStep; + warp_tile0_k = tmp_warp_tile_k_coord_LUT & MaskPerStep; + } else { + warp_tile0_n = step < NumStepsEncoded ? (tmp_warp_tile_n_coord_LUT & MaskPerStep) : 4 + warp_idx_in_warp_group; + warp_tile0_k = step < NumStepsEncoded ? (tmp_warp_tile_k_coord_LUT & MaskPerStep) : step - 4; + } + + int warp_tile1_n = warp_tile0_n == warp_tile0_k ? warp_tile0_n + 1 : warp_tile0_k; + int warp_tile1_k = warp_tile0_n == warp_tile0_k ? warp_tile0_k + 1 : warp_tile0_n; + + CUTLASS_PRAGMA_UNROLL + for (int warp_group_tile = 0; warp_group_tile < WarpgroupTileNum; ++warp_group_tile) { + + static_assert(TilesPerWarp == 2); + + // [warp_tile][n/k] + const int warp_tile_coord[TilesPerWarp][2] = { + // n k + {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile0_n, warp_tile0_k}, // warp_tile 0 + {warp_group_tile * NumWarpTilePerWarpgroupTile + warp_tile1_n, warp_tile1_k} // warp_tile 1 + }; + + CUTLASS_PRAGMA_UNROLL + for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) { + Tensor tCsB = sB_thr_copy.partition_S( + flatten(s_tile(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) + ); // (CPY, CPY_N, CPY_K) + + copy(sB_tiled_copy, tCsB, transpose_fragments[warp_tile]); + } + + // Make sure elements in two 8x8 warp tiles are all consumed + __syncwarp(); + + CUTLASS_PRAGMA_UNROLL + for (int warp_tile = 0; warp_tile < TilesPerWarp; ++warp_tile) { + Tensor tCsB_transposed = sB_thr_copy.partition_D( + flatten(s_tile_transposed(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) + ); // (CPY, CPY_N, CPY_K) + copy(sB_tiled_copy, transpose_fragments[warp_tile], tCsB_transposed); + } + + } // loop warp_group_tile + } + + CUTLASS_DEVICE void synchronize(int step) { + if (step < StepsPerWarpGroup) { + // SMEM fence to make sure B is transposed before math + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 1); + } + } + + CUTLASS_DEVICE void synchronize() { + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 1); + } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void transpose( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + int read_stage) { + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < StepsPerWarpGroup; ++i) { + this->operator()(sB, gmma_sB, read_stage, i); + } + synchronize(); + + } +private: + const int warp_idx; + const int warp_group_thread_idx; + const int warp_idx_in_warp_group; + const int current_warp_tile_n_coord_LUT; + const int current_warp_tile_k_coord_LUT; +}; + +template< + class TiledMma_, + class SmemLayoutB_, + class SmemLayoutAtomB_, + class ElementB_> +class AsyncTranspositionOperandB_1BElementB { +public: + + static_assert(sizeof(ElementB_) == 1); + + using TiledMma = TiledMma_; + using SmemLayoutB = SmemLayoutB_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using ElementB = ElementB_; + + static constexpr int Steps = 8; + static constexpr int NumMathWarpGroup = size(TiledMma{}) / NumThreadsPerWarpGroup; + static constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup; + static_assert(NumMathWarpGroup <= 2, + "Wrong math warp group number for TransposeB"); + static constexpr int WarpgroupTileSize = size<1>(SmemLayoutB{}); // A warp group tile would process entire Smem K. + static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; + + static constexpr int BytesPerSmemSwizzleUnit = 16; + static constexpr int WarpThreadShapeN = BytesPerSmemSwizzleUnit / sizeof(ElementB); + static constexpr int WarpThreadShapeK = NumThreadsPerWarp / WarpThreadShapeN; + static constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1); + + static constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile; + static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invaild warp thread shape." ); + static constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step. + static constexpr int64_t WarpTileNCoordLUT = 06723763275316420; + static constexpr int64_t WarpTileKCoordLUT = 05410541064206420; + static constexpr int NumStepsEncoded = 4; // Only encoding first 4 steps into LUT. + static constexpr int MaskPerStep = 07; // Each step is encoded into 3bits, + static constexpr int NumBitsPerStep = 3; + static constexpr int MaskPerWarp = 07777; // Each warp has 4 steps(12 bits) + static constexpr int NumBitsPerWarp = 12; + // Number of warp_group_tiles + static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0, + "Copy size must evenly divide SMEM tile."); + static constexpr int WarpgroupTileNum = size<0>(SmemLayoutB{}) / WarpgroupTileSize; + + + constexpr CUTLASS_HOST_DEVICE + AsyncTranspositionOperandB_1BElementB( + int warp_idx_, + int warp_group_thread_idx_, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, + ElementB) + : warp_idx(warp_idx_) + , warp_group_thread_idx(warp_group_thread_idx_) + , warp_idx_in_warp_group(warp_idx_ % NumWarpsPerWarpGroup) + , current_warp_tile_n_coord_LUT((WarpTileNCoordLUT >> ((warp_idx_ + % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) + , current_warp_tile_k_coord_LUT((WarpTileKCoordLUT >> ((warp_idx_ + % NumWarpsPerWarpGroup) * NumBitsPerWarp)) & MaskPerWarp) { } + + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void operator()( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + int read_stage, int current_step) + { + if (current_step > 0) { + return; + } + constexpr auto WarpThreadLayout = make_layout(make_shape(Int{}, Int{})); - constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step. ////////////////////////////////////////////////////////////////////////////////////////////////////////////// - /// A warp group uses 2 or 8 steps to transpose the whole WarpgroupTileSize x WarpgroupTileSize. - /// In each step, one warp would hold two warp_tiles. - /// Step 0: Step 1: - /// W0 W1 W2 W3 -- -- -- -- - /// W1 W0 -- -- -- -- W3 W2 - /// W2 -- -- -- -- W3 W0 W1 - /// W3 -- -- -- -- W2 W1 W1 - /// OR: + /// A warp group uses 8 steps to transpose the whole WarpgroupTileSize x WarpgroupTileSize. /// Divide a warp_group_tile into 8x8 warp_tiles to futher reduce the reg usage. /// Step 0: Step 1: Step 2: Step 3: /// W0 W1 W2 W3 -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- -- @@ -183,26 +593,11 @@ transpose_b_operand ( /// The 2-step transposition and the 8-step transposition share the same encoding. /// ////////////////////////////////////////////////////////////////////////////////////////////////////////////// - constexpr int64_t WarpTileNCoordLUT = 06723763275316420; - constexpr int64_t WarpTileKCoordLUT = 05410541064206420; - constexpr int NumStepsEncoded = 4; // Only encoding first 4 steps into LUT. - constexpr int MaskPerStep = 07; // Each step is encoded into 3bits, - constexpr int NumBitsPerStep = 3; - constexpr int MaskPerWarp = 07777; // Each warp has 4 steps(12 bits) - constexpr int NumBitsPerWarp = 12; - - const int current_warp_tile_n_coord_LUT = (WarpTileNCoordLUT >> (warp_idx * NumBitsPerWarp)) & MaskPerWarp; - const int current_warp_tile_k_coord_LUT = (WarpTileKCoordLUT >> (warp_idx * NumBitsPerWarp)) & MaskPerWarp; - - // Number of warp_group_tiles - static_assert(size<0>(SmemLayoutB{}) % WarpgroupTileSize == 0, - "Copy size must evenly divide SMEM tile."); - constexpr int WarpgroupTileNum = size<0>(SmemLayoutB{}) / WarpgroupTileSize; // Divide entire SMEM to multiple warp_tiles constexpr auto WarpTileShape = make_shape(Int(), Int()); - Tensor s_tile = zipped_divide( sB(_,_,smem_pipe_read.index()), WarpTileShape); - Tensor s_tile_transposed = zipped_divide(gmma_sB(_,_,smem_pipe_read.index()), WarpTileShape); + Tensor s_tile = zipped_divide( sB(_,_,read_stage), WarpTileShape); + Tensor s_tile_transposed = zipped_divide(gmma_sB(_,_,read_stage), WarpTileShape); // Get copy tile auto sB_tiled_copy = make_tiled_copy( @@ -210,7 +605,7 @@ transpose_b_operand ( WarpThreadLayout, // thr_layout Layout<_1>{} // val_layout ); - static_assert(size(sB_tiled_copy) * NumWarpsPerWarpGroup == size(TiledMma{}), "Wrong thread number in TiledCopy."); + static_assert(size(sB_tiled_copy) * NumWarpsPerWarpGroup == size(TiledMma{}) / NumMathWarpGroup, "Wrong thread number in TiledCopy."); auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx % NumThreadsPerWarp); // slice based on lane_idx // Construct fragments for transposition @@ -224,11 +619,19 @@ transpose_b_operand ( for (int warp_group_tile = 0; warp_group_tile < WarpgroupTileNum; ++warp_group_tile) { int tmp_warp_tile_n_coord_LUT = current_warp_tile_n_coord_LUT; int tmp_warp_tile_k_coord_LUT = current_warp_tile_k_coord_LUT; + constexpr int StepsPerWarpGroup = Steps / NumMathWarpGroup; + + if constexpr (NumMathWarpGroup == 2) { + tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2)); + tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep * (warp_idx / (NumWarpsPerWarpGroup * 2)); + } CUTLASS_PRAGMA_NO_UNROLL - for (int step = 0; step < Steps; ++step) { + for (int step_per_warp_group = 0; step_per_warp_group < StepsPerWarpGroup; ++step_per_warp_group) { + // For 2 math warpgroup, warp idx4~7 is 1st warp group and 8~9 is 2nd, so decide if 2nd warpgroup need warp idx divide 8. + int step = step_per_warp_group * NumMathWarpGroup + warp_idx / (NumWarpsPerWarpGroup * 2); // decoding the warp tile coord. - int warp_tile0_n = step < NumStepsEncoded ? (tmp_warp_tile_n_coord_LUT & MaskPerStep) : 4 + warp_idx; + int warp_tile0_n = step < NumStepsEncoded ? (tmp_warp_tile_n_coord_LUT & MaskPerStep) : 4 + warp_idx_in_warp_group; int warp_tile0_k = step < NumStepsEncoded ? (tmp_warp_tile_k_coord_LUT & MaskPerStep) : step - 4; int warp_tile1_n = warp_tile0_n == warp_tile0_k ? warp_tile0_n + 1 : warp_tile0_k; int warp_tile1_k = warp_tile0_n == warp_tile0_k ? warp_tile0_k + 1 : warp_tile0_n; @@ -236,6 +639,8 @@ transpose_b_operand ( tmp_warp_tile_n_coord_LUT >>= NumBitsPerStep; tmp_warp_tile_k_coord_LUT >>= NumBitsPerStep; + static_assert(TilesPerWarp == 2); + // [warp_tile][n/k] const int warp_tile_coord[TilesPerWarp][2] = { // n k @@ -248,6 +653,7 @@ transpose_b_operand ( Tensor tCsB = sB_thr_copy.partition_S( flatten(s_tile(_, make_coord(warp_tile_coord[warp_tile][0], warp_tile_coord[warp_tile][1]))) ); // (CPY, CPY_N, CPY_K) + copy(sB_tiled_copy, tCsB, transpose_fragments[warp_tile]); } @@ -261,68 +667,81 @@ transpose_b_operand ( ); // (CPY, CPY_N, CPY_K) copy(sB_tiled_copy, transpose_fragments[warp_tile], tCsB_transposed); } - } // lock step } // loop warp_group_tile - } // if not use universal transposition - - ////////////////////////////////////////////////////////////////////////////////////////////////////////////// - /// Universal transposition, need warp_group sync between load and store. - /// The number of reg used depends on the input elementB. - ////////////////////////////////////////////////////////////////////////////////////////////////////////////// - else { - /* - In one copy step, a warp group would load WarpgroupTileSize * WarpgroupTileSize tile then store to transposed location. - In warp_group_tile, each warp holds Four WarpTileSize x WarpTileSize elements: - K - ------------ - | W0 W1 W2 W3 --- - | W0 W1 W2 W3 | - | W0 W1 W2 W3 | --> Copy Step 0 - | W0 W1 W2 W3 --- - .... - | W0 W1 W2 W3 --- - | W0 W1 W2 W3 | - | W0 W1 W2 W3 | --> Copy Step n - | W0 W1 W2 W3 --- - */ - static_assert((NumThreadsPerWarpGroup % WarpThreadShapeN == 0), "Unsupported warp thread layout."); - constexpr auto WarpgroupThreadLayout = make_layout(make_shape(Int{}, Int{})); - - // Get copy tile and partition to each thread - auto sB_tiled_copy = make_tiled_copy( - Copy_Atom{}, - WarpgroupThreadLayout, // thr_layout - Layout<_1>{} // val_layout - ); - static_assert(size(sB_tiled_copy) == size(TiledMma{}), "Wrong thread number in TiledCopy."); - - auto sB_thr_copy = sB_tiled_copy.get_thread_slice(warp_group_thread_idx); - Tensor tCsB = sB_thr_copy.partition_S( sB(_,_,smem_pipe_read.index())); // (CPY, CPY_N, CPY_K) - Tensor tCsB_transposed = sB_thr_copy.partition_D(gmma_sB(_,_,smem_pipe_read.index())); // (CPY, CPY_N, CPY_K) + } - // Divide partitioned tile to limit register usage - constexpr int CopySteps = size<0>(SmemLayoutB{}) / WarpgroupTileSize; - constexpr auto CopyTileShape = make_shape(size<0>(tCsB), Int< size<1>(tCsB) / CopySteps >{}, size<2>(tCsB)); - static_assert(size<1>(tCsB) % CopySteps == 0, "CopySteps must evenly divide rank 1 size of partitioned SMEM."); + CUTLASS_DEVICE void synchronize(int step) { + if (step == 0) { + // SMEM fence to make sure B is transposed before math + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 1); + } + } - Tensor tCsB_copy_tile = zipped_divide(tCsB, CopyTileShape); - Tensor tCsB_copy_tile_transposed = zipped_divide(tCsB_transposed, CopyTileShape); - auto transpose_fragment = make_fragment_like(tCsB_copy_tile(_,_0{})); + CUTLASS_DEVICE void synchronize() { + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(size(TiledMma{}), 1); + } - CUTLASS_PRAGMA_NO_UNROLL - for (int step = 0; step < CopySteps; ++step) { - copy(sB_tiled_copy, tCsB_copy_tile(_,step), transpose_fragment); + template < + class TensorSmemB, + class TensorTransposedSmemB> + CUTLASS_DEVICE void transpose( + TensorSmemB const& sB, + TensorTransposedSmemB const& gmma_sB, + int read_stage) { + this->operator()(sB, gmma_sB, read_stage, 0); + synchronize(); + } - // Make sure all elements are read before being overwritten - __syncthreads(); +private: + const int warp_idx; + const int warp_group_thread_idx; + const int warp_idx_in_warp_group; + const int current_warp_tile_n_coord_LUT; + const int current_warp_tile_k_coord_LUT; +}; - copy(sB_tiled_copy, transpose_fragment, tCsB_copy_tile_transposed(_,step)); - } - } // if use universal transposition - // SMEM fence to make sure B is transposed before math - cutlass::arch::fence_view_async_shared(); +template< + class TiledMma, + class SmemLayoutB, + class SmemLayoutAtomB, + class ElementB, + bool TransposeB +> +constexpr CUTLASS_HOST_DEVICE +auto +make_transpose_operand_b( + int warp_idx, + int warp_group_thread_idx, + TiledMma, + SmemLayoutB, + SmemLayoutAtomB, + ElementB, + cute::bool_constant) +{ + if constexpr (!TransposeB) { + return NoTranspositionOperandB( + warp_idx, warp_group_thread_idx, TiledMma{}, + SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); + } + else if constexpr (use_universal_transposition()) { + return UniversalTranspositionOperandB( + warp_idx, warp_group_thread_idx, TiledMma{}, + SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); + } + else if constexpr (sizeof(ElementB) == 1) { + return AsyncTranspositionOperandB_1BElementB( + warp_idx, warp_group_thread_idx, TiledMma{}, + SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); + } + else { + return AsyncTranspositionOperandB( + warp_idx, warp_group_thread_idx, TiledMma{}, + SmemLayoutB{}, SmemLayoutAtomB{}, ElementB{}); + } } }; // namespace detail diff --git a/include/cutlass/uint128.h b/include/cutlass/uint128.h index d56f95bdc5..68ad4f98a6 100644 --- a/include/cutlass/uint128.h +++ b/include/cutlass/uint128.h @@ -164,7 +164,6 @@ struct uint128_t { uint64_t overflow; y.hilo_.hi += _umul128(hilo_.hi, rhs, &overflow); #else - // TODO - not implemented CUTLASS_UNUSED(rhs); exception(); #endif @@ -182,7 +181,6 @@ struct uint128_t { uint64_t remainder = 0; quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); #else - // TODO - not implemented CUTLASS_UNUSED(divisor); exception(); #endif @@ -199,7 +197,6 @@ struct uint128_t { // implemented using MSVC's arithmetic intrinsics (void)_udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); #else - // TODO - not implemented CUTLASS_UNUSED(divisor); exception(); #endif @@ -217,7 +214,6 @@ struct uint128_t { // implemented using MSVC's arithmetic intrinsics quotient = _udiv128(hilo_.hi, hilo_.lo, divisor, &remainder); #else - // TODO - not implemented CUTLASS_UNUSED(remainder); CUTLASS_UNUSED(divisor); exception(); diff --git a/include/cutlass/workspace.hpp b/include/cutlass/workspace.hpp new file mode 100644 index 0000000000..35f0c0f8be --- /dev/null +++ b/include/cutlass/workspace.hpp @@ -0,0 +1,71 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Utilities for initializing workspaces +*/ +#pragma once + +#if !defined(__CUDACC_RTC__) +#include "cuda_runtime.h" + +#include "cutlass/trace.h" +#endif + +#include "cutlass.h" + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// +#if !defined(__CUDACC_RTC__) +static Status +zero_workspace(void* workspace, int workspace_size, cudaStream_t stream = nullptr) { + if (workspace_size > 0) { + if (workspace == nullptr) { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + return Status::kErrorWorkspaceNull; + } + + CUTLASS_TRACE_HOST(" clearing barrier workspace"); + cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_size, stream); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + return Status::kSuccess; +} +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/media/docs/cute/00_quickstart.md b/media/docs/cute/00_quickstart.md index df7ceadc7e..a9c35f1be3 100644 --- a/media/docs/cute/00_quickstart.md +++ b/media/docs/cute/00_quickstart.md @@ -69,7 +69,10 @@ Other files in this directory discuss specific parts of CuTe. * [`0t_mma_atom.md`](./0t_mma_atom.md) demonstrates CuTe's meta-information and interface to our GPUs' architecture-specific Matrix Multiply-Accumulate (MMA) instructions. -* [`0x_gemm_tutorial.md`](./0x_gemm_tutorial.md) provides a walkthrough of building a GEMM from scratch using CuTe. +* [`0x_gemm_tutorial.md`](./0x_gemm_tutorial.md) walks through building a GEMM from scratch using CuTe. * [`0y_predication.md`](./0y_predication.md) explains what to do if a tiling doesn't fit evenly into a matrix. + +* [`0z_tma_tensors.md`](./0z_tma_tensors.md) summarizes + how CuTe supports TMA loads and stores. diff --git a/media/docs/cute/03_tensor.md b/media/docs/cute/03_tensor.md index 2382d834f7..ccd25ae3ee 100644 --- a/media/docs/cute/03_tensor.md +++ b/media/docs/cute/03_tensor.md @@ -127,7 +127,8 @@ This results in ### CuTe's provided `Engine` types -CuTe comes with three `Engine` types. +CuTe comes with a few `Engine` types. +Here are the three that new users are most likely to encounter first. * `ArrayEngine`: an owning `Engine`, representing an array of `N` elements of type `T` diff --git a/media/docs/cute/0z_tma_tensors.md b/media/docs/cute/0z_tma_tensors.md new file mode 100644 index 0000000000..3e0d0b1c7d --- /dev/null +++ b/media/docs/cute/0z_tma_tensors.md @@ -0,0 +1,83 @@ +# TMA tensors + +TMA tensors have three differences from +"ordinary" global memory tensors. + +1. The tensor's iterator stores a base coordinate, + not a pointer. + +2. The tensor's actual global memory pointer + does not live in the tensor. + Instead, it lives in a TMA descriptor, + which is stored in the TMA `Copy_Traits` specialization. + +3. The tensor's strides aren't just integers. + Instead, they are linear combinations of "basis functions." + +The following sections will elaborate these differences. + +## Iterator stores a base coordinate, not a pointer + +"Ordinary" tensors of global memory have an iterator type +(the "Engine" template parameter) that wraps a pointer. +For example, `gmem_ptr` wraps a `T*`. +A TMA tensor's iterator type is `ArithmeticTupleIterator`. +`ArithmeticTupleIterator` stores a coordinate +(a tuple of integers) instead of a pointer. +The coordinate is represented as an `ArithmeticTuple`, +which is just a (public subclass of) `cute::tuple` +that has an overloaded `operator+`. +The sum of two tuples is the tuple of the sum of the elements. + +When we perform the TMA load or store, +the iterator's coordinate goes into the PTX instruction. +(For TMA specializations of `Copy_Traits`, +this happens in the `private` member function `copy_unpack_`.) +The coordinate represents the tensor's "base coordinate." +For tiled TMA, the base coordinate of the whole tensor +might start out as (0, 0, ..., 0). However, slicing the tensor +might result in a different base coordinate. +For im2col TMA load, the base coordinate is the lower corner. + +## Pointer lives in TMA descriptor, not tensor + +The TMA descriptor has the actual pointer to global memory in it. +Storing the TMA descriptor in the tensor would make tensors +expensive to copy and slice, as the TMA descriptor is 128 bytes. +Instead, we store the TMA descriptor +in the `Copy_Traits` specialization. + +## Tensor's strides aren't just integers + +For "ordinary" tensors, the layout takes a coordinate +`(i, j)` as input, and returns a single integer offset `k`. +The resulting pointer-to-element +is the base pointer, plus the offset k. +However, TMA loads and stores don't take a pointer. +They take a TMA descriptor, and a coordinate `(i, j)`. +Building the strides out of "basis functions" +is the trick to make the layout return a coordinate -- +a tuple of integers -- instead of just a single integer offset. +A "basis function" for strides +is a lot like a basis function for Euclidean space, +except that strides' basis functions can be hierarchical. + +Layouts work by taking the inner product +of their input coordinate with the strides. +For "ordinary" integer strides, e.g., `(1, 100)`, +the inner product of the input coordinate `(i, j)` +and the strides is `i + 100j`. +That gives the formula for the offset. +For strides built of basis functions, for example, +if the strides are `(_1@0, _1@1)`, +then the inner product of the input coordinate `(i, j)` +with the strides is `i@0 + j@1`. +The `i` here is a coefficient of the basis function `@0`, +and `j` is a coefficient of the basis function `@1`. +The result is a vector sum. We _interpret_ this result as +"the zeroth coefficient is i, and the first coefficient is j." +That translates into the (TMA) coordinate `(i, j)`. +If we wanted to reverse the coordinates, +then we could use `(_1@1, _1@0)` as the strides. +Evaluating the layout would give `i@1 + j@0`, +that is, `(j, i)`. diff --git a/media/docs/cutlass_3x_backwards_compatibility.md b/media/docs/cutlass_3x_backwards_compatibility.md index 723d783d2e..354e70dd48 100644 --- a/media/docs/cutlass_3x_backwards_compatibility.md +++ b/media/docs/cutlass_3x_backwards_compatibility.md @@ -101,7 +101,7 @@ template < class ProblemShapeOrThreadblockMma_, class CollectiveMainloopOrEpilogue_, class CollectiveEpilogueOrThreadblockSwizzle_, - class GridSwizzle_ = void, + class TileScheduler_ = void, class Enable = void > class GemmUniversal; diff --git a/media/docs/gemm_api_3x.md b/media/docs/gemm_api_3x.md index 04f4215219..8197d2e721 100644 --- a/media/docs/gemm_api_3x.md +++ b/media/docs/gemm_api_3x.md @@ -379,7 +379,7 @@ may also change in the future as we adopt user feedback. If the builder is able to provide a collective mainloop type for the given set of parameters, it will be aliased within as `CollectiveOp`. For more information on how to -parameterize kernels conveniently with the collective builder, please see example [49_hopper_gemm_schedules_with_collective_builder](49_hopper_gemm_schedules_with_collective_builder). +parameterize kernels conveniently with the collective builder, please see example [49_hopper_gemm_with_collective_builder](/examples/49_hopper_gemm_with_collective_builder). ### Epilogue @@ -387,7 +387,7 @@ The collective epilogue implements element-wise operations involving the output matrix. Users can provide a custom epilogue, or use one of the standard epilogues. These live in the directory -[include/cutlass/epilogue/collective/](../../include/cutlass/epilogue/collective/), +[include/cutlass/epilogue/collective/](/include/cutlass/epilogue/collective/), and include classes like `cutlass::epilogue::collective::DefaultEpilogue` and @@ -415,7 +415,7 @@ epilogues, and/or other operations. The entry point API for CUTLASS 3.0 kernel is the class `cutlass::gemm::kernel::GemmUniversal`, found in the header file -[include/cutlass/gemm/kernel/gemm_universal.hpp](../../include/cutlass/gemm/kernel/gemm_universal.hpp). +[include/cutlass/gemm/kernel/gemm_universal.hpp](/include/cutlass/gemm/kernel/gemm_universal.hpp). `GemmUniversal` is a stateless universal device kernel that implements GEMM as the composition of two parts: @@ -442,7 +442,7 @@ template < class ProblemShapeOrThreadblockMma_, // (m, n, k) or (m, n, k, l) class CollectiveMainloopOrEpilogue_, class CollectiveEpilogueOrThreadblockSwizzle_, - class GridSwizzle_ = void, + class TileScheduler_ = void, class Enable = void > class GemmUniversal; @@ -475,24 +475,24 @@ We will explain *collective* in more detail below. Specializations of `kernel::GemmUniversal` for 3.0 APIs live in any of various `gemm_*.hpp` files in the directory -[include/cutlass/gemm/kernel/](../../include/cutlass/gemm/kernel/). +[include/cutlass/gemm/kernel/](/include/cutlass/gemm/kernel/). Specializations for 2.x APIs can be found in the header file -[include/cutlass/gemm/kernel/gemm_universal.h](../../include/cutlass/gemm/kernel/gemm_universal.h). +[include/cutlass/gemm/kernel/gemm_universal.h](/include/cutlass/gemm/kernel/gemm_universal.h). CUTLASS 3.x implements various embodiments of `kernel::GemmUniversal`. Each kernel layer schedule is specialized for a GEMM scheduling algorithm and GPU architecture. Specializations of `kernel::GemmUniversal` for 3.0 APIs live in any of various `include/cutlass/gemm/kernel/{arch_tag}*.hpp` files in the directory -[include/cutlass/gemm/kernel/](../../include/cutlass/gemm/kernel/). +[include/cutlass/gemm/kernel/](/include/cutlass/gemm/kernel/). Which specialization to dispatch to is decided through the dispatch policy's `Schedule` type. For example, the header file -[include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp](../../include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) +[include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) has a specialization of `kernel::GemmUniversal` for Hopper that uses a warp-specialized mainloop with a persistent scheduling algorithm, while the header file -[include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp](../../include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) +[include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) has a specialization of `GemmUniversal` for Hopper that uses a warp-specialized but non-persistent algorithm. @@ -510,13 +510,13 @@ template < class ProblemShape_, class CollectiveMainloop_, class CollectiveEpilogue_, - class GridSwizzle_ + class TileScheduler_ > class GemmUniversal< ProblemShape_, CollectiveMainloop_, CollectiveEpilogue_, - GridSwizzle_, + TileScheduler_, std::enable_if_t>> ``` diff --git a/media/docs/programming_guidelines.md b/media/docs/programming_guidelines.md index aba270aa4a..4be52bf57d 100644 --- a/media/docs/programming_guidelines.md +++ b/media/docs/programming_guidelines.md @@ -55,7 +55,6 @@ structure should also include a data member corresponding to each data member in be properly constructed in host code. The parent class should define a constructor which accepts `Params const &` as its first argument. - ### Composable Shared Memory Shared memory requires explicit effort by the programmer to allocate and de-allocate. CUTLASS follows the paradigm @@ -155,18 +154,18 @@ When declaring functions, indent function parameters like this. ```c++ void possibly_an_unusually_long_function_name( - std::uint32_t foo - std::uint32_t const* bar, - TypeA a, - TypeB b, - TypeC c) { + std::uint32_t foo + std::uint32_t const* bar, + TypeA a, + TypeB b, + TypeC c) { // ... the function's body ... } ``` A newline should not be inserted between the parenthesis that closes the function's parameters and the curly bracket -that opens the function's body. +that opens the function's body. Note the double indent for function parameters. #### If-else brackets and spacing @@ -743,10 +742,15 @@ These include for functions that run on the host and the device, * `CUTLASS_DEVICE` or `CUTE_DEVICE` - for functions that run on the device only, and + for functions that run on the device only, * `CUTE_HOST` - for functions that run on the host only; and + for functions that run on the host only, and + + * `CUTE_HOST_RTC` + for functions that run on the host only, + but occur as unevaluated operands (of e.g., `decltype` or `sizeof`; + see C++ Standard, `[expr.context]` 1) in device code; and * annotations to loop unrolling: @@ -759,6 +763,20 @@ These include Use `#pragma once` to guard all headers. +### CuTe Layout Comments + +* Right align CuTe layout comments at column 120. +* If layout comment is too long do your best to align it. +* If layout comment is too long and there are many related tensors that reader should read together, try to align the layout comments of related tensors. + +```c++ + Tensor my_tensor = make_tensor(Layout{}, Stride<_1,_2>>{}); // (2,2):(1,2) + + // Related tensors + Tensor my_tensor1 = make_tensor(ThisIsAVeryComplicatedLayoutWithAVeryLongName); // ((Mode0_0,Mode0_1,Mode0_2),Mode1,Mode2,Mode3) + Tensor my_tensor2_related = make_tensor(ThisIsAVeryComplicatedLayoutWithAVeryLongName); // ((Mode0_0,Mode0_1,Mode0_2),Mode1,Mode2,Mode3) +``` + ### CUDA C++ style #### CUDA Built-in Variables @@ -795,6 +813,26 @@ CuTe has replaced CUTLASS 2.x components such as [Layouts](layout.md), and [`TensorRef` and `TensorView`](layout.md#tensorref). +## CUTLASS idioms + +### Detecting major mode + +Developers sometimes need to detect whether a tensor is MN-major or K-major. +(For definitions, see the [CuTe GEMM tutorial](./cute/0x_gemm_tutorial.md).) + +* _Correct_: `cutlass::detail::is_major<0, Stride>()` or +`cutlass::detail::is_k_major()` from `include/cutlass/gemm/gemm.h` + +* _Incorrect_: `get<0>(stride) == 1` + +The second point is incorrect because it assumes that the mode +is a single integer, not a multimode. +This means that the code will fail to compile for tensor contractions. +For example, suppose that a tensor A +has shape `((X, Y), K)` and stride `((1, X), X*Y)`. +`get<0>(stride)` is the tuple `(1, X)`, not a single integer. +However, A is certainly M major if interpreted as a matrix. + # Copyright Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. diff --git a/media/docs/quickstart.md b/media/docs/quickstart.md index 1f92a91ab6..c43882cc4d 100644 --- a/media/docs/quickstart.md +++ b/media/docs/quickstart.md @@ -425,10 +425,10 @@ int main(int argc, char const **args) { StrideC stride_C; StrideD stride_D; - stride_A = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, Int<1>{})); - stride_B = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, Int<1>{})); - stride_C = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, Int<1>{})); - stride_D = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, Int<1>{})); + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, Int<1>{})); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, Int<1>{})); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, Int<1>{})); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, Int<1>{})); block_A.reset(M * K); block_B.reset(K * N); diff --git a/python/cutlass/__init__.py b/python/cutlass/__init__.py index 03f96437bf..77724f7e83 100644 --- a/python/cutlass/__init__.py +++ b/python/cutlass/__init__.py @@ -71,12 +71,15 @@ def _cuda_install_path_from_nvcc() -> str: DataType, DataTypeSize, EpilogueFunctor, + EpilogueScheduleSuffixes, + EpilogueScheduleTag, + EpilogueScheduleType, GemmKind, LayoutTag, LayoutType, KernelScheduleSuffixes, - KernelScheduleType, KernelScheduleTag, + KernelScheduleType, MathInstruction, MathOperation, OpcodeClass, @@ -85,6 +88,9 @@ def _cuda_install_path_from_nvcc() -> str: SwizzlingFunctor, TensorDescription, TileDescription, + TileSchedulerSuffixes, + TileSchedulerTag, + TileSchedulerType ) this = sys.modules[__name__] @@ -106,11 +112,12 @@ def set_log_level(level: int): this.option_registry = OptionRegistry(device_cc()) -this.__version__ = '3.1.0' +this.__version__ = '3.2.0' from cutlass.backend import get_memory_pool from cutlass.emit.pytorch import pytorch from cutlass.op.gemm import Gemm +from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad from cutlass.op.gemm_grouped import GroupedGemm from cutlass.op.op import OperationBase diff --git a/python/cutlass/backend/c_types.py b/python/cutlass/backend/c_types.py index 7212e414ae..0a429bf8ad 100644 --- a/python/cutlass/backend/c_types.py +++ b/python/cutlass/backend/c_types.py @@ -118,6 +118,7 @@ class GenericMainloopArguments3x_(ctypes.Structure): ("stride_A", StrideBatched_), ("ptr_B", ctypes.c_void_p), ("stride_B", StrideBatched_), + ("mma_promotion_interval", ctypes.c_int) ] @@ -148,12 +149,14 @@ class _MainloopArgumentsTma(ctypes.Structure): ("stride_A", StrideBatched_), ("ptr_B", ctypes.c_void_p), ("stride_B", StrideBatched_), + ("mma_promotion_interval", ctypes.c_int) ] @staticmethod def from_generic_mainloop_args(args: GenericMainloopArguments3x_): return _MainloopArgumentsTma( args.ptr_A, args.stride_A, args.ptr_B, args.stride_B, + args.mma_promotion_interval ) class _MainloopArgumentsMultistage(ctypes.Structure): @@ -203,15 +206,23 @@ class _EpilogueArguments(ctypes.Structure): ("stride_D", StrideBatched_), ] + class _HardwareInfo(ctypes.Structure): + _fields_ = [ + ("device_id", ctypes.c_int), + ("sm_count", ctypes.c_int) + ] + class _GemmArguments(ctypes.Structure): _fields_ = [ ("mode", ctypes.c_int), ("problem_size", GemmCoordBatched_), ("mainloop", mainloop_arguments), - ("epilogue", _EpilogueArguments) + ("epilogue", _EpilogueArguments), + ("hw_info", _HardwareInfo), + ("splits", ctypes.c_int) ] - return _GemmArguments, _EpilogueArguments, _EpilogueOutputOpParams + return _GemmArguments, _EpilogueArguments, _EpilogueOutputOpParams, _HardwareInfo def get_gemm_arguments(epilogue_functor): diff --git a/python/cutlass/backend/compiler.py b/python/cutlass/backend/compiler.py index 3b9ae1a6e6..21117b031b 100644 --- a/python/cutlass/backend/compiler.py +++ b/python/cutlass/backend/compiler.py @@ -39,16 +39,37 @@ from cuda import cuda, nvrtc import cutlass_bindings -from cutlass import CACHE_FILE, CUDA_INSTALL_PATH, CUTLASS_PATH +from cutlass import CACHE_FILE, CUDA_INSTALL_PATH, CUTLASS_PATH, logger from cutlass.backend.gemm_operation import GemmOperationUniversal from cutlass.backend.library import ApiVersion from cutlass.backend.utils.device import device_cc from cutlass.backend.utils.software import SubstituteTemplate +import subprocess IncludeTemplate = r"""#include "${include}" """ +def compile_with_nvcc(cmd, source, error_file): + succeed = True + try: + subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=True) + except subprocess.CalledProcessError as e: + error_message = e.output.decode() + with open(error_file, "w") as error_out: + error_log = "Compilation error for the following kernel: \n" + error_log += source + error_log += "\nError Message:\n" + error_log += error_message + error_out.write(error_log) + succeed = False + if not succeed: + # Print the error log to stdout if log level is set to warning or higher + # verbosity. Otherwise, simply point to the error log file. + logger.warning(error_log) + raise Exception(f"Invalid Kernel. See '{error_file}' for details.") + + class CompilationOptions: """ Compilation options. @@ -129,20 +150,24 @@ def __init__(self) -> None: connection.commit() cursor.close() + self._nvrtc_compile_options = ["-std=c++17", "-default-device"] + self._nvcc_compile_options = [ + "-std=c++17", + "--expt-relaxed-constexpr", + "-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored", + ] self.nvcc() self.compiled_cache_device = cutlass_bindings.CompileCache() self.compiled_cache_host = cutlass_bindings.CompileCache() def nvrtc(self): self.backend = "nvrtc" - self.default_compile_options = ["-std=c++17", "-default-device"] + self.default_compile_options = self._nvrtc_compile_options + def nvcc(self): self.backend = "nvcc" - self.default_compile_options = [ - "-std=c++17", - "--expt-relaxed-constexpr", - "-Xcudafe --diag_suppress=esa_on_defaulted_function_ignored", - ] + self.default_compile_options = self._nvcc_compile_options + def insert_operation(self, op_key, cubin, hostfile, op_name, op_attrs): connection = sqlite3.connect(CACHE_FILE) cursor = connection.cursor() @@ -200,7 +225,7 @@ def load_operation(self, op_key, extra_funcs): self.compiled_cache_host.insert(key, compiled_host_fns) return True - def emit_compile_(self, operation_list, compilation_options): + def emit_compile_(self, operation_list, compilation_options, host_compilation_options): """ Compile a list of kernels and store them into database """ @@ -299,7 +324,7 @@ def emit_compile_(self, operation_list, compilation_options): "tarfile": temp_cubin.name, } cmd = SubstituteTemplate(cmd_template, values) - os.system(cmd) + compile_with_nvcc(cmd, source_buffer_device, "./cutlass_python_compilation_device_error.txt") # load the cubin image with open(temp_cubin.name, "rb") as file: @@ -314,7 +339,7 @@ def emit_compile_(self, operation_list, compilation_options): cmd_template, { "cuda_install_path": CUDA_INSTALL_PATH, - "options": compilation_options.get_str(), + "options": host_compilation_options.get_str(), }, ) @@ -323,29 +348,31 @@ def emit_compile_(self, operation_list, compilation_options): prefix="host_func", suffix=".so", delete=True) cmd += " - -shared -o %s -lcudart -lcuda" % temp.name - os.system(cmd) + compile_with_nvcc(cmd, source_buffer_host, error_file="./cutlass_python_compilation_host_error.txt") host_lib = ctypes.CDLL(temp.name) return cubin_image, host_lib, temp - def add_module(self, operations, compile_options=None): + def add_module(self, operations, compile_options=None, bypass_cache=False): """ Insert a new compiled device module """ + include_paths = [ + CUDA_INSTALL_PATH + "/include", + CUTLASS_PATH + "/include", + CUTLASS_PATH + "/tools/util/include", + CUTLASS_PATH + "/python/cutlass/cpp/include", + ] + + if device_cc() is not None: + arch = device_cc() + else: + # Find the maximum arch tag among the provided operations and compile for that target. + # Since we are compiling to .cubin files, only one architecture may be specified. + arch = max([op.arch for op in operations]) + host_compile_options = CompilationOptions( + self._nvcc_compile_options, arch, include_paths) if compile_options is None: - include_paths = [ - CUDA_INSTALL_PATH + "/include", - CUTLASS_PATH + "/include", - CUTLASS_PATH + "/tools/util/include", - CUTLASS_PATH + "/python/cutlass/cpp/include", - ] - - if device_cc() is not None: - arch = device_cc() - else: - # Find the maximum arch tag among the provided operations and compile for that target. - # Since we are compiling to .cubin files, only one architecture may be specified. - arch = max([op.arch for op in operations]) compile_options = CompilationOptions( self.default_compile_options, arch, include_paths) # save the cubin @@ -357,7 +384,7 @@ def add_module(self, operations, compile_options=None): # step 1: check if the operation is in cache compiled_kernel = self.compiled_cache_device.at(key) - if compiled_kernel is None: + if compiled_kernel is None and not bypass_cache: hit = self.load_operation(key, getattr( operation.rt_module, "extra_funcs", {})) if hit: compiled_kernel = self.compiled_cache_device.at(key) @@ -375,7 +402,7 @@ def add_module(self, operations, compile_options=None): if len(operation_list) > 0: cubin_image, host_lib, host_file = self.emit_compile_( - operation_list, compile_options) + operation_list, compile_options, host_compile_options) err, module = cuda.cuModuleLoadData(cubin_image) if err != cuda.CUresult.CUDA_SUCCESS: diff --git a/python/cutlass/backend/conv2d_operation.py b/python/cutlass/backend/conv2d_operation.py index 8dc55a25b1..977f5d4cba 100644 --- a/python/cutlass/backend/conv2d_operation.py +++ b/python/cutlass/backend/conv2d_operation.py @@ -41,6 +41,7 @@ from cutlass.backend.arguments import ArgumentBase from cutlass.backend.c_types import Conv2DProblemSize, TensorRef_, get_conv2d_arguments from cutlass.backend.library import ( + EmissionType, ConvKindNames, ConvKindTag, DataTypeNames, @@ -123,17 +124,17 @@ def __init__( super().__init__(A, B, C, D, **kwargs) # preprocessing output ops - if "output_op" in kwargs.keys() and split_k_mode != cutlass_bindings.conv.SplitKMode.Parallel: - self.output_op = kwargs["output_op"] - else: - self.output_op = self.operation.epilogue_type(1.0, 0.0) - - if "split_k_slices" in kwargs.keys(): + if "split_k_slices" in kwargs.keys() and kwargs["split_k_slices"] > 1: self.split_k_mode = split_k_mode self.split_k_slices = kwargs["split_k_slices"] else: self.split_k_mode = cutlass_bindings.conv.SplitKMode.Serial self.split_k_slices = 1 + + if "output_op" in kwargs.keys() and self.split_k_mode != cutlass_bindings.conv.SplitKMode.Parallel: + self.output_op = kwargs["output_op"] + else: + self.output_op = self.operation.epilogue_type(1.0, 0.0) #: problem_size self.problem_size: cutlass_bindings.conv.Conv2dProblemSize = problem_size @@ -419,7 +420,9 @@ def __init__( C: TensorDescription, stride_support, epilogue_functor, - swizzling_functor=cutlass_bindings.IdentitySwizzle1 + swizzling_functor=cutlass_bindings.IdentitySwizzle1, + emission_type=EmissionType.Kernel, + **kwargs ): self.operation_kind: OperationKind = OperationKind.Conv2d self.arch: int = arch @@ -432,6 +435,8 @@ def __init__( self.iterator_algorithm = iterator_algorithm self.stride_support = stride_support self.swizzling_functor = swizzling_functor() + + self.emission_type = emission_type self.rt_module: Conv2dRT = Conv2dRT(self) self.argument_type = self.rt_module.argument_type @@ -562,6 +567,18 @@ def accumulator_type(self): return accum + def device_op(self): + """ + Returns a new Conv2dOperation object that is constructed with emission type + ``EmissionType.Device``. + + :return: operation ready for device-level code emission + :rtype: Conv2dOperation + """ + return Conv2dOperation( + self.conv_kind, self.iterator_algorithm, self.arch, self.tile_description, + self.A, self.B, self.C, self.stride_support, self.epilogue_functor, type(self.swizzling_functor), + emission_type=EmissionType.Device) ################################################################################################### # @@ -596,7 +613,7 @@ def __init__(self, operation_suffix=""): cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, ${epilogue_functor}, - ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, + ${swizzling_functor}, ${stages}, ${math_operator}, ${iterator_algorithm}, @@ -608,6 +625,36 @@ def __init__(self, operation_suffix=""): struct ${operation_name}${operation_suffix}: public ${operation_name}_base { }; +""" + + self.template_device = """ +// Conv2d operation ${operation_name} + +using Conv2d${conv_kind_name}Kernel = typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${math_operator}, + ${iterator_algorithm}, + ${stride_support}, + ${align_a}, + ${align_b} +>::Kernel; + +using DeviceKernel = + typename cutlass::conv::device::ImplicitGemmConvolution; """ def emit(self, operation): @@ -651,5 +698,10 @@ def emit(self, operation): "align_a": str(operation.A.alignment), "align_b": str(operation.B.alignment), } + + if operation.emission_type == EmissionType.Kernel: + conv2d_template = self.template + else: + conv2d_template = self.template_device - return SubstituteTemplate(self.template, values) + return SubstituteTemplate(conv2d_template, values) diff --git a/python/cutlass/backend/gemm_operation.py b/python/cutlass/backend/gemm_operation.py index c10056df11..706a1467f3 100644 --- a/python/cutlass/backend/gemm_operation.py +++ b/python/cutlass/backend/gemm_operation.py @@ -39,7 +39,17 @@ import numpy as np import rmm -from cutlass import KernelScheduleSuffixes, KernelScheduleTag, KernelScheduleType +from cutlass import ( + EpilogueScheduleSuffixes, + EpilogueScheduleTag, + EpilogueScheduleType, + KernelScheduleSuffixes, + KernelScheduleTag, + KernelScheduleType, + TileSchedulerSuffixes, + TileSchedulerTag, + TileSchedulerType +) from cutlass.backend.arguments import ArgumentBase from cutlass.backend.c_types import ( GemmCoord_, @@ -55,6 +65,7 @@ ) from cutlass.backend.library import ( ApiVersion, + EmissionType, ComplexTransformTag, DataTypeNames, DataTypeSize, @@ -548,6 +559,7 @@ def get_arguments(self): stride_A, int(self.ptr_B), stride_B, + 4 # mma_promotion_interval ) # Set of mainloop arguments needed for this kernel @@ -561,11 +573,15 @@ def get_arguments(self): stride_D, ) + # Set hardware info + hw_info = self.operation.rt_module.hw_info(0, device_sm_count()) + self.arguments = self.operation.argument_type( self.gemm_mode, problem_size_, mainloop, epilogue, + hw_info, ) return self.arguments @@ -1102,6 +1118,11 @@ class GemmRTUniversal3x(GemmRTUniversal): using GemmType = ${operation_name}_base; + // Get the workspace size + uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* argument) { + return GemmType::get_workspace_size(*argument); + } + // Get the params as byte array char* ${operation_name}_get_params(GemmType::Arguments* argument, int* workspace){ GemmType::Params params = GemmType::to_underlying_arguments(*argument, workspace); @@ -1118,7 +1139,7 @@ class GemmRTUniversal3x(GemmRTUniversal): uint64_t ${operation_name}_get_persistent_tiled_blk_shape_mnl(GemmType::ProblemShape problem) { auto problem_shape_MNKL = append<4>(problem, Int<1>{}); auto [problem_blocks_m, problem_blocks_n, problem_blocks_l] = - cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::get_tiled_blk_shape_mnl( + cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::get_tiled_cta_shape_mnl( problem_shape_MNKL, GemmType::TileShape{}, GemmType::DispatchPolicy::ClusterShape{}); return problem_blocks_m * problem_blocks_n * problem_blocks_l; } @@ -1141,7 +1162,8 @@ def __init__(self, operation: "GemmOperation"): self.extra_funcs = { "get_grid_shape": dim3_, "get_block_shape": dim3_, - "get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64 + "get_persistent_tiled_blk_shape_mnl": ctypes.c_uint64, + "get_kernel_workspace_size": ctypes.c_uint64 } self.emitter = EmitGemmUniversalInstance3x("_type") self.mainloop_args = get_mainloop_arguments_3x( @@ -1151,7 +1173,10 @@ def __init__(self, operation: "GemmOperation"): operation.A.alignment, operation.B.alignment ) - self.argument_type, self.epilogue_args, self.epilogue_type = get_gemm_arguments_3x(self.mainloop_args, operation.epilogue_functor) + self.argument_type, self.epilogue_args, self.epilogue_type, self.hw_info = get_gemm_arguments_3x(self.mainloop_args, operation.epilogue_functor) + + def get_device_workspace_size(self, arguments: GemmArguments3x): + return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments())) class EmitGemmUniversalInstance3x: @@ -1183,7 +1208,7 @@ def __init__(self, operation_suffix=""): ${element_accumulator}, ${element_epilogue}, ${element_c}, ${layout_c}, ${align_c}, ${element_d}, ${layout_d}, ${align_d}, - cutlass::epilogue::collective::EpilogueScheduleAuto + ${epilogue_schedule} >::CollectiveOp; using CollectiveMainloop = @@ -1202,7 +1227,8 @@ def __init__(self, operation_suffix=""): using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< Shape, CollectiveMainloop, - CollectiveEpilogue + CollectiveEpilogue, + ${tile_scheduler} >; // Define named type @@ -1233,9 +1259,15 @@ def emit(self, operation): else: gemm_template = self.gemm_template_device - schedule = KernelScheduleType.ScheduleAuto + kschedule = KernelScheduleType.ScheduleAuto + eschedule = EpilogueScheduleType.ScheduleAuto + tschedule = TileSchedulerType.Default if operation.tile_description.kernel_schedule is not None: - schedule = operation.tile_description.kernel_schedule + kschedule = operation.tile_description.kernel_schedule + if operation.tile_description.epilogue_schedule is not None: + eschedule = operation.tile_description.epilogue_schedule + if operation.tile_description.tile_scheduler is not None: + tschedule = operation.tile_description.tile_scheduler values = { "operation_name": operation.procedural_name(), @@ -1264,7 +1296,9 @@ def emit(self, operation): "align_c": str(operation.C.alignment), "align_d": str(operation.C.alignment), "stage_count_type": stage_count_type, - "kernel_schedule": KernelScheduleTag[schedule], + "kernel_schedule": KernelScheduleTag[kschedule], + "epilogue_schedule": EpilogueScheduleTag[eschedule], + "tile_scheduler": TileSchedulerTag[tschedule] } values["epilogue_functor"] = operation.epilogue_functor.emit() @@ -1382,15 +1416,6 @@ def get_workspace_size(self, arguments): ################################################################################ -class EmissionType(enum.Enum): - """ - Tags for whether to emit a kernel- or device-level operation - """ - - Kernel = enum_auto() - Device = enum_auto() - - class GemmOperationBase: """ CUTLASS GEMM operation @@ -1595,11 +1620,18 @@ def kernel_schedule_name_3x(self): else: return KernelScheduleSuffixes[self.tile_description.kernel_schedule] + # Generates a short string representing underlying epilogue schedule type + def epilogue_schedule_name_3x(self): + if self.tile_description.epilogue_schedule is None: + return EpilogueScheduleSuffixes[EpilogueScheduleType.ScheduleAuto] + else: + return EpilogueScheduleSuffixes[self.tile_description.epilogue_schedule] + def procedural_name(self): """The full procedural name indicates architecture, extended name, tile size, and layout.""" opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] if self.api == ApiVersion.v3x and self.arch >= 90: - kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}" + kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}" return kernel_name_template.format( p=self.prefix, ar=self.arch, @@ -1614,7 +1646,8 @@ def procedural_name(self): l=self.tile_description.stages, s=self.layout_name_3x(), al=str(self.A.alignment), - k=self.kernel_schedule_name_3x() + k=self.kernel_schedule_name_3x(), + e=self.epilogue_schedule_name_3x() ) else: threadblock = self.tile_description.procedural_name() diff --git a/python/cutlass/backend/library.py b/python/cutlass/backend/library.py index 7760f6ed2b..18b56b030f 100644 --- a/python/cutlass/backend/library.py +++ b/python/cutlass/backend/library.py @@ -38,7 +38,7 @@ import enum import cutlass_bindings -from cutlass import KernelScheduleType +from cutlass import EpilogueScheduleType, KernelScheduleType, TileSchedulerType # The following block implements enum.auto() for Python 3.5 variants that don't include it such @@ -554,7 +554,9 @@ def __init__( warp_count, math_instruction, cluster_shape=[1, 1, 1], - kernel_schedule: KernelScheduleType = None + kernel_schedule: KernelScheduleType = None, + epilogue_schedule: EpilogueScheduleType = None, + tile_scheduler: TileSchedulerType = None, ): """ :param threadblock_shape: shape of a threadblock tyle @@ -568,18 +570,61 @@ def __init__( :type math_instruction: MathInstruction :param cluster_shape: number of threadblocks in the [X, Y, Z] dimensions of a threadblock cluster :param kernel_schedule: type of kernel schedule to use (only available for SM90+) - :type kernel_schedule: cutlass.backend.KernelScheduleType + :type kernel_schedule: cutlass.KernelScheduleType + :param epilogue_schedule: type of epilogue schedule to use (only available for SM90+) + :type epilogue_schedule: cutlass.EpilogueScheduleType + :param tile_scheduler: type of tile scheduler to use (only available for SM90+) + :type tile_scheduler: cutlass.TileSchedulerType """ + if ((kernel_schedule is None and epilogue_schedule is not None) or + (kernel_schedule is not None and epilogue_schedule is None)): + raise Exception("Kernel and epilogue schedule must either both be Auto or neither be Auto.") + self.threadblock_shape = threadblock_shape self.cluster_shape = cluster_shape self.kernel_schedule = kernel_schedule - self.stages: int = stages + self.epilogue_schedule = epilogue_schedule + self.tile_scheduler = tile_scheduler + self.stages = stages self.math_instruction = math_instruction + self.instruction_shape = math_instruction.instruction_shape # Number of warps along x, y, z directions self.warp_count = warp_count + def clone_and_update(self, td: dict): + attrs = { + "cluster_shape": None, + "threadblock_shape": None, + "warp_count": None, + "stages": None, + "instruction_shape": None, + "kernel_schedule": None, + "epilogue_schedule": None, + "tile_scheduler": None + } + for key in attrs.keys(): + if key in td.keys(): + attrs[key] = td[key] + else: + attrs[key] = getattr(self, key) + + mi = MathInstruction( + attrs["instruction_shape"], + self.math_instruction.element_a, + self.math_instruction.element_b, + self.math_instruction.element_accumulator, + self.math_instruction.opcode_class, + self.math_instruction.math_operation + ) + + return TileDescription( + attrs["threadblock_shape"], attrs["stages"], + attrs["warp_count"], mi, attrs["cluster_shape"], + attrs["kernel_schedule"], attrs["epilogue_schedule"] + ) + @property def num_threads(self): """ @@ -622,16 +667,30 @@ def __str__(self): :return: contents of tile description :rtype: str """ - schedule = KernelScheduleType.ScheduleAuto if self.kernel_schedule is not None: - schedule = self.kernel_schedule + kschedule = self.kernel_schedule + else: + kschedule = KernelScheduleType.ScheduleAuto + + if self.epilogue_schedule is not None: + eschedule = self.epilogue_schedule + else: + eschedule = EpilogueScheduleType.ScheduleAuto + + if self.tile_scheduler is not None: + tschedule = self.tile_scheduler.name + else: + tschedule = "None" return f""" {{ ClusterShape: {self.cluster_shape} ThreadblockShape: {self.threadblock_shape} WarpCount: {self.warp_count} Stages: {self.stages if self.stages is not None else 'Auto'} - Kernel schedule: {schedule.name} + InstructionShape: {self.math_instruction.instruction_shape} + Kernel schedule: {kschedule.name} + Epilogue schedule: {kschedule.name} + TileScheduler: {tschedule} }}""" @@ -712,3 +771,12 @@ def api_version(arch, opclass, datatype): return ApiVersion.v3x else: return ApiVersion.v2x + + +class EmissionType(enum.Enum): + """ + Tags for whether to emit a kernel- or device-level operation + """ + + Kernel = enum_auto() + Device = enum_auto() diff --git a/python/cutlass/backend/test/conv2d_testbed.py b/python/cutlass/backend/test/conv2d_testbed.py index b6c960555f..3715b47515 100644 --- a/python/cutlass/backend/test/conv2d_testbed.py +++ b/python/cutlass/backend/test/conv2d_testbed.py @@ -38,7 +38,7 @@ import cutlass_bindings import numpy as np -from cutlass.backend.compiler import ArtifactManager +from cutlass.backend import compiler from cutlass.backend.conv2d_operation import Conv2dArguments, Conv2dOperation from cutlass.backend.library import DataTypeSize, ShortDataTypeNames, StrideSupport from cutlass.backend.memory_manager import get_allocated_size @@ -127,7 +127,6 @@ def getTensorView(tensor, tensor_layout, conv_kind, problem_size, operand): raise ValueError("unsupported data type") -# @typechecked class Conv2dLauncher: """ Launcher that runs the operation on given problem size @@ -142,6 +141,7 @@ def __init__( profiling=False, warmup_iterations=500, iterations=500, + compilation_mode="nvcc", **kwargs, ) -> None: self.enable_cached_results = True @@ -176,7 +176,14 @@ def __init__( # Compile the operator # - ArtifactManager().add_module([operation, self.reduction_operation]) + if compilation_mode == "nvcc": + compiler.nvcc() + elif compilation_mode == "nvrtc": + compiler.nvrtc() + else: + raise Exception(f"Unexpected compilation mode {compilation_mode}") + + compiler.add_module([operation, self.reduction_operation]) self.operation = operation @@ -195,14 +202,14 @@ def __init__( element_size = DataTypeSize[operation.A.element] if element_size <= 8: - self.scope = 1 + self.randomization_max = 1 elif element_size == 16: if accumulator_size <= 16: - self.scope = 2 + self.randomization_max = 2 else: - self.scope = 4 + self.randomization_max = 4 else: - self.scope = 7 + self.randomization_max = 7 # Seed self.seed = seed @@ -263,12 +270,12 @@ def uniform_init(self, size, dtype): if dtype in [np.float32, np.float16, bfloat16, np.float64]: return np.ceil( np.random.uniform( - low=-self.scope - 0.5, high=self.scope - 0.5, size=size + low=-self.randomization_max - 0.5, high=self.randomization_max - 0.5, size=size ).astype(dtype) ) else: return np.random.uniform( - low=-self.scope - 1, high=self.scope + 1, size=size + low=-self.randomization_max - 1, high=self.randomization_max + 1, size=size ).astype(dtype) def eq_gemm_size(self, problem_size): @@ -624,13 +631,15 @@ def run( ############################################################################################################ -def test_all_conv2d(operation: Conv2dOperation, conv_test_sizes=[], interleaved=False): +def test_all_conv2d_from_compilation_mode( + operation: Conv2dOperation, + conv_test_sizes, + interleaved, + compilation_mode): + passed = True - # - # Testbed object - # - testbed = Conv2dLauncher(operation, interleaved=interleaved) + testbed = Conv2dLauncher(operation, interleaved=interleaved, compilation_mode=compilation_mode) # # Get conv problem sizes to run conv operator @@ -781,3 +790,18 @@ def test_all_conv2d(operation: Conv2dOperation, conv_test_sizes=[], interleaved= ) return passed + + +def test_all_conv2d( + operation: Conv2dOperation, + conv_test_sizes=[], + interleaved=False, + compilation_modes=["nvcc", "nvrtc"]): + + for compilation_mode in compilation_modes: + passed = test_all_conv2d_from_compilation_mode(operation, conv_test_sizes, interleaved, compilation_mode) + + if not passed: + return False + + return True diff --git a/python/cutlass/backend/test/gemm_testbed.py b/python/cutlass/backend/test/gemm_testbed.py index 3790f17055..a52c41bd85 100644 --- a/python/cutlass/backend/test/gemm_testbed.py +++ b/python/cutlass/backend/test/gemm_testbed.py @@ -177,6 +177,7 @@ def __init__( profiling=False, warmup_iterations=500, iterations=500, + compiler_mode: str = "nvcc", **kwargs, ) -> None: # create the reduction kernel @@ -209,13 +210,19 @@ def __init__( # # Compile the operator # + if compiler_mode == "nvcc": + compiler.nvcc() + elif compiler_mode == "nvrtc": + compiler.nvrtc() + else: + raise Exception(f"Unexpected compiler string {compiler_mode}") op_list = [operation] if operation.arch < 90: # Split K via Python is currently only supported for pre-SM90 kernels op_list.append(self.reduction_operation) - compiler.add_module(op_list) + compiler.add_module(op_list, bypass_cache=True) self.operation = operation @@ -603,7 +610,7 @@ def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, be return passed -def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal"): +def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal", compilation_mode="nvcc"): passed = True minimum_operand_element_size = min( @@ -711,7 +718,7 @@ def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal"): problem_alpha = [1.0] problem_beta = [2.0] - testbed = GemmUniversalLauncher(operation, interleaved=(testcase == "interleaved")) + testbed = GemmUniversalLauncher(operation, interleaved=(testcase == "interleaved"), compiler_mode=compilation_mode) for mode in modes: for m in problem_size_m: diff --git a/python/cutlass/backend/test/utils.py b/python/cutlass/backend/test/utils.py index 1489a4aa0f..7aa9b211aa 100644 --- a/python/cutlass/backend/test/utils.py +++ b/python/cutlass/backend/test/utils.py @@ -30,10 +30,13 @@ # ################################################################################################# +import cutlass import cutlass_bindings -from cutlass import KernelScheduleSuffixes +from cutlass import EpilogueScheduleSuffixes, KernelScheduleSuffixes +from cutlass.utils.datatypes import binding_opclass, binding_type from cutlass.backend import library +from cutlass.backend.test.gemm_testbed import test_all_gemm from cutlass.backend.utils.software import SubstituteTemplate @@ -75,6 +78,7 @@ def get_name( arch, opclass, kernel_schedule=None, + epilogue_schedule=None, suffix="", ): """ @@ -97,24 +101,26 @@ def get_name( :type opclass: cutlass_bindings.OpClass :param kernel_schedule: kernel_schedule type :type kernel_schedule: cutlass.KernelScheduleType + :param epilogue_schedule: epilogue_schedule type + :type epilogue_schedule: cutlass.EpilogueScheduleType :param suffix: additional string to add to the suffix of the name :type suffix: str :return: str """ - name_format = "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${suffix}" + name_format = "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${e}${suffix}" return SubstituteTemplate( name_format, { "arch": str(arch), - "eA": library.DataTypeNames[element_a], - "eB": library.DataTypeNames[element_b], - "eC": library.DataTypeNames[element_output], + "eA": library.DataTypeNames[binding_type(element_a)], + "eB": library.DataTypeNames[binding_type(element_b)], + "eC": library.DataTypeNames[binding_type(element_output)], "lA": library.ShortLayoutTypeNames[layouts[0]], "lB": library.ShortLayoutTypeNames[layouts[1]], "lC": library.ShortLayoutTypeNames[layouts[2]], - "opclass": library.OpcodeClassNames[opclass], - "acc": library.DataTypeNames[element_accumulator], + "opclass": library.OpcodeClassNames[binding_opclass(opclass)], + "acc": library.DataTypeNames[binding_type(element_accumulator)], "cM": str(cluster_shape[0]), "cN": str(cluster_shape[1]), "cK": str(cluster_shape[2]), @@ -126,6 +132,174 @@ def get_name( "aB": str(alignments[1]), "aC": str(alignments[2]), "k": "" if kernel_schedule is None else KernelScheduleSuffixes[kernel_schedule], + "e": "" if epilogue_schedule is None else EpilogueScheduleSuffixes[epilogue_schedule], "suffix": "" if suffix is None else suffix, }, ) + +def get_name_conv2d( + arch, + conv_kind, + element, + element_accumulator, + element_output, + opclass, + threadblock_shape, + warp_count, + instruction_shape, + stages, + iterator_algorithm, + swizzle, + split_k_mode, + split_k_slices, + activation +): + """ + Generates a procedural name for a test case for conv2d + + :param arch: compute capability of kernel being generated + :type arch: int + :param conv_kind: the convolution type (i.e. fprop, dgrad, wgrad) + :type conv_kind: str + :param iterator_algorithm: the iterator algorithm applied + :type iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm + :param element_a: data type of operand A + :param element_b: data type of operand B + :param element_c: data type of operand C + :param element_accumulator: data type used in accumulation + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass_bindings.OpClass + :param threadblock_shape: indexable container of dimensions of threadblock tiles + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param stride_support: stride support of dgrad + :param alignment: int + :type alignment: int + + :return: str + """ + if iterator_algorithm is None: + iterator_algorithm = "AUTO" + if swizzle is None: + swizzle = 1 + name_format = "test_SM${arch}_Device_Conv2d_${conv_kind}_${iter_alg}_ImplicitGemm_${eA}nhwc_${eB}nhwc_${eC}nhwc_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${wM}x${wN}x${wK}_${IM}${IN}${IK}_stage${stages}_swizzle${swizzle}_${split_k_mode}${split_k_slices}_${activation}" + + return SubstituteTemplate( + name_format, + { + "arch": str(arch), + "conv_kind": conv_kind, + "iter_alg": iterator_algorithm, + "eA": library.DataTypeNames[binding_type(element)], + "eB": library.DataTypeNames[binding_type(element)], + "eC": library.DataTypeNames[binding_type(element_output)], + "opclass": opclass, + "acc": library.DataTypeNames[binding_type(element_accumulator)], + "tbM": str(threadblock_shape[0]), + "tbN": str(threadblock_shape[1]), + "tbK": str(threadblock_shape[2]), + "wM": str(threadblock_shape[0] // warp_count[0]), + "wN": str(threadblock_shape[1] // warp_count[1]), + "wK": str(threadblock_shape[2] // warp_count[2]), + "IM": str(instruction_shape[0]), + "IN": str(instruction_shape[1]), + "IK": str(instruction_shape[2]), + "stages": str(stages), + "swizzle": str(swizzle), + "split_k_mode": split_k_mode, + "split_k_slices": str(split_k_slices), + "activation": activation + } + ) + + +def add_test_gemm( + cls=None, + cc=None, + element=None, + layouts=None, + alignments=None, + element_output=None, + element_accumulator=None, + cluster_shape=None, + threadblock_shape=None, + warp_count=None, + stages=None, + opclass=None, + swizzle=None, + kernel_schedule=None, + epilogue_schedule=None, + compilation_modes=['nvcc', 'nvrtc']): + """ + Create test-running functions with the given specification and set it as a method of ``cls``. + + :param cls: class to which the generated method will be added + :type cls: type + :param cc: compute capability to compile for + :type cc: int + :param element: data type of A and B operands + :type element: cutlass.DataType.f16 + :param layouts: layouts of A, B, and C operands + :type layouts: list or tuple + :param alignments: alingments of A, B, and C operands + :type alignments: list or tuple + :param element_output: data type of the output element + :type element_output: cutlass.DataType + :param element_accumulator: data type used in accumulation + :type element_accumulator: cutlass.DataType + :param cluster_shape: dimensions of clusters + :type cluster_shape: list or tuple + :param threadblock_shape: dimensions of threadblock tiles + :type threadblock_shape: list or tuple + :param warp_count: warps to be launched per threadblock dimension + :type warp_count: list or tuple + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass.OpClass + :param swizzle: threadblock swizzling functor + :param kernel_schedule: kernel schedule to use + :type kernel_schedule: cutlass.KernelScheduleType + :param epilogue_schedule: epilogue schedule to use + :type epilogue_schedule: cutlass.EpilogueScheduleType + :param compilation_modes: list of compilers to used in testing the kernel (options: 'nvrtc', 'nvcc') + :type compilation_modes: list + """ + + for compilation_mode in compilation_modes: + def run(self): + """ + Dynamically-generated function that constructs a GEMM operation and verifies it against + multiple test cases. + """ + element_A = element + element_B = element + layout_A, layout_B, layout_C = layouts + alignment_A, alignment_B, alignment_C = alignments + + plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, + element_C=element_output, element_D=element_output, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, + element_accumulator=element_accumulator, + kernel_cc=cc) + + plan.opclass = opclass + if swizzle is not None: + plan.swizzling_functor = swizzle + td = plan.tile_descriptions()[0] + td.threadblock_shape = threadblock_shape + td.stages = stages + if warp_count is not None: + td.warp_count = warp_count + td.cluster_shape = cluster_shape + op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) + self.assertTrue(test_all_gemm(op, 'universal', compilation_mode=compilation_mode)) + + element_epilogue = element_accumulator + name = get_name( + layouts=layouts, alignments=alignments, element_output=element_output, element_accumulator=element_accumulator, + element_epilogue=element_epilogue, cluster_shape=cluster_shape, threadblock_shape=threadblock_shape, + stages=stages, element_a=element, element_b=element, arch=cc, opclass=opclass, + kernel_schedule=kernel_schedule, epilogue_schedule=epilogue_schedule, suffix=f'_{compilation_mode}') + + setattr(cls, name, run) diff --git a/python/cutlass/cpp/include/swizzling.h b/python/cutlass/cpp/include/swizzling.h index 27994dedd8..e306624597 100644 --- a/python/cutlass/cpp/include/swizzling.h +++ b/python/cutlass/cpp/include/swizzling.h @@ -135,9 +135,8 @@ void bind_dgrad_swizzle(py::module & m, std::string name) { :param problem_size: Implicit gemm problem size conv_operator(NPQK, NHWC, KRSC) :type problem_size: :class:`cutlass.gemm.GemmCoord`) )pbdoc") - .def("get_grid_shape", [](const T & swizzle, cutlass::gemm::GemmCoord tiled_shape) { - return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k()); - }, py::arg("tiled_shape"), + .def("get_grid_shape", &T::get_grid_shape, + py::arg("tiled_shape"), R"pbdoc(Computes CUDA grid dimensions given a size in units of logical tiles)pbdoc") .def("tag", [](const T & swizzle){ return demangle(typeid(T).name()); diff --git a/python/cutlass/cpp/test/conv/host.h b/python/cutlass/cpp/test/conv/host.h index 142c468848..69a98390a6 100644 --- a/python/cutlass/cpp/test/conv/host.h +++ b/python/cutlass/cpp/test/conv/host.h @@ -155,7 +155,8 @@ void bind_conv_host_references(py::module &m) { /// Cache py::class_(m, "CachedTestKey") .def(py::init<>()) - .def(py::init()); + .def(py::init()) + .def_readwrite("problem", &test::conv::device::CachedTestKey::problem); py::class_(m, "CachedTestResult") .def(py::init<>()) diff --git a/python/cutlass/emit/common.py b/python/cutlass/emit/common.py index 4d1dd4cd4b..a8deb85616 100644 --- a/python/cutlass/emit/common.py +++ b/python/cutlass/emit/common.py @@ -117,16 +117,16 @@ typename DeviceKernel::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, - {M, N, K, L}, // problem size - A, // ptrA - make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A - B, // ptrB - make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B + {M, N, K, L}, // problem size + A, // ptrA + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A + B, // ptrB + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B { - C, // ptrC - make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C - D, // ptrD - make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D + C, // ptrC + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C + D, // ptrD + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D {alpha, beta}, }, hw_info @@ -180,3 +180,86 @@ return status; } """ + + +_CUTLASS_KERNEL_RUN_CONV2D_2x = """ + +using UnderlyingKernel = typename DeviceKernel::UnderlyingKernel; +namespace { +using TensorRefA = typename UnderlyingKernel::TensorRefA; +using TensorRefB = typename UnderlyingKernel::TensorRefB; +using TensorRefC = typename UnderlyingKernel::TensorRefC; +using ElementCompute = typename UnderlyingKernel::EpilogueOutputOp::ElementCompute; +} + +template +TensorRef get_tensor_ref(cutlass::Tensor4DCoord tensor_coord, Element* ptr){ + cutlass::layout::TensorNHWC layout = cutlass::layout::TensorNHWC::packed(tensor_coord); + TensorRef tensor_ref(ptr, layout); + return tensor_ref; +} + +cutlass::Status ${name}_kernel_run(cutlass::conv::Conv2dProblemSize* problem_size, + UnderlyingKernel::ElementA* A, UnderlyingKernel::ElementB* B, + UnderlyingKernel::ElementC* C, UnderlyingKernel::ElementC* D, + ElementCompute alpha, ElementCompute beta, std::string split_k_mode, + cudaStream_t stream, int device_id=0) { + // create the tensor references + cutlass::Tensor4DCoord tensor_coord_A = cutlass::conv::implicit_gemm_tensor_a_extent( + cutlass::conv::Operator::k${conv_kind_name}, *problem_size + ); + cutlass::Tensor4DCoord tensor_coord_B = cutlass::conv::implicit_gemm_tensor_b_extent( + cutlass::conv::Operator::k${conv_kind_name}, *problem_size + ); + cutlass::Tensor4DCoord tensor_coord_C = cutlass::conv::implicit_gemm_tensor_c_extent( + cutlass::conv::Operator::k${conv_kind_name}, *problem_size + ); + + TensorRefA tensor_ref_A = get_tensor_ref(tensor_coord_A, A); + TensorRefB tensor_ref_B = get_tensor_ref(tensor_coord_B, B); + TensorRefC tensor_ref_C = get_tensor_ref(tensor_coord_C, C); + TensorRefC tensor_ref_D = get_tensor_ref(tensor_coord_C, D); + + cutlass::conv::SplitKMode mode; + if (split_k_mode == "serial") { + mode = cutlass::conv::SplitKMode::kSerial; + } else if (split_k_mode == "parallel") { + mode = cutlass::conv::SplitKMode::kParallel; + } else { + throw std::runtime_error("Invalid split_k_mode: " + split_k_mode); + } + + typename DeviceKernel::Arguments arguments{ + *problem_size, + tensor_ref_A, + tensor_ref_B, + tensor_ref_C, + tensor_ref_D, + {alpha, beta}, + mode + }; + + DeviceKernel implicit_gemm_op; + + size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); + + void* workspace_ptr = device_memory_allocation(workspace_size, device_id); + + cutlass::Status status = implicit_gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return status; + } + + status = implicit_gemm_op.initialize(arguments, workspace_ptr, stream); + if (status != cutlass::Status::kSuccess) { + return status; + } + + // + // Launch initialized CUTLASS kernel + // + status = implicit_gemm_op(stream); + + return status; +} +""" diff --git a/python/cutlass/emit/pytorch.py b/python/cutlass/emit/pytorch.py index 61cc5d94db..1beedd0606 100644 --- a/python/cutlass/emit/pytorch.py +++ b/python/cutlass/emit/pytorch.py @@ -85,7 +85,8 @@ from cutlass import CUTLASS_PATH, logger, swizzle from cutlass.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal -from cutlass.backend.library import ApiVersion +from cutlass.backend.conv2d_operation import Conv2dOperation +from cutlass.backend.library import ApiVersion, ConvKindNames from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate from cutlass.emit import common @@ -95,12 +96,26 @@ _PYTORCH_CUDA_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ +#include #include #include - +#include #include "cutlass/cutlass.h" #include "cutlass/util/device_memory.h" +// helper function allocating the memory +void* device_memory_allocation(size_t size, int device_id=0) { + if (size > 0) { + torch::Device device(torch::kCUDA, device_id); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + torch::TensorOptions options = torch::TensorOptions().dtype(torch::kI8).device(device); + at::Tensor device_tensor = torch::empty({(long)size,}, options); + return reinterpret_cast(device_tensor.data_ptr()); + } else { + return nullptr; + } +} + ${includes} ${declaration} ${impl} @@ -143,6 +158,72 @@ } """ +_PYTORCH_CONV2D_FPROP_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ +#include +#include +#include + +// CUDA forward declarations +at::Tensor ${name}_kernel( + const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, + float alpha=1.f, float beta=0.f, + std::string split_k_mode="serial", int split_k_slices=1); + +// C++ interface +at::Tensor ${name}( + const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, + float alpha=1.f, float beta=0.f, + std::string split_k_mode="serial", int split_k_slices=1) { + return ${name}_kernel(A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("run", + py::overload_cast< + const at::Tensor&, const at::Tensor&, at::optional, + std::tuple, std::tuple, std::tuple, float, float, std::string, int>( + &${name}), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, + py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1), + py::arg("alpha") = 1.f, py::arg("beta") = 0.f, + py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1); +} +""" + +_PYTORCH_CONV2D_GRAD_CPP_TEMPLATE = common._CSTYLE_AUTOGEN_COMMENT + """ +#include +#include +#include + +// CUDA forward declarations +at::Tensor ${name}_kernel( + std::tuple result_size, const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, + float alpha=1.f, float beta=0.f, + std::string split_k_mode="serial", int split_k_slices=1); + +// C++ interface +at::Tensor ${name}( + std::tuple result_size, const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, + float alpha=1.f, float beta=0.f, + std::string split_k_mode="serial", int split_k_slices=1) { + return ${name}_kernel(result_size, A, B, C, stride, padding, dilation, alpha, beta, split_k_mode, split_k_slices); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("run", + py::overload_cast< + std::tuple, const at::Tensor&, const at::Tensor&, at::optional, + std::tuple, std::tuple, std::tuple, float, float, std::string, int>( + &${name}), py::arg("result_size"), py::arg("A"), py::arg("B"), py::arg("C") = nullptr, + py::arg("stride") = std::make_tuple(1, 1), py::arg("padding") = std::make_tuple(1, 1), py::arg("dilation") = std::make_tuple(1, 1), + py::arg("alpha") = 1.f, py::arg("beta") = 0.f, + py::arg("split_k_mode") = "serial", py::arg("split_k_slices") = 1); +} +""" + _PYTORCH_GEMM_INCLUDES = { ApiVersion.v2x: """ #include "cutlass/gemm/device/gemm_universal.h" @@ -162,6 +243,13 @@ #include "cutlass/gemm/device/gemm_grouped.h" """ +_PYTORCH_CONV2D_INCLUDES = """ +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/kernel/default_conv2d_dgrad.h" +#include "cutlass/conv/kernel/default_conv2d_wgrad.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" +""" + _CUTLASS_TYPE_TO_TORCH_TYPE = { cutlass_bindings.float16: "torch::kF16", cutlass_bindings.float32: "torch::kF32", @@ -356,6 +444,133 @@ """ ) +_PYTORCH_CONV2D_IMPL_TEMPLATE_2x = """ + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + cutlass::Status status = ${name}_kernel_run( + &problem_size, + reinterpret_cast(A.data_ptr()), + reinterpret_cast(B.data_ptr()), + ptrC, + reinterpret_cast(D.data_ptr()), + alpha, beta, + split_k_mode, stream, B.device().index()); + + TORCH_CHECK(status == cutlass::Status::kSuccess, "CUTLASS kernel failed"); + return D; +} +""" + +_PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x = ( + common._CUTLASS_KERNEL_RUN_CONV2D_2x + + """ +at::Tensor ${name}_kernel(const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, + float alpha=1.f, float beta=0.f, std::string split_k_mode="serial", int split_k_slices=1) { + int N, H, W, C_, K, R, S, P, Q; + N = A.size(0); + C_ = A.size(1); + H = A.size(2); + W = A.size(3); + + K = B.size(0); + R = B.size(2); + S = B.size(3); + + cutlass::conv::Conv2dProblemSize problem_size( + cutlass::Tensor4DCoord(N, H, W, C_), + cutlass::Tensor4DCoord(K, R, S, C_), + cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)), + cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)), + cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)), + cutlass::conv::Mode::kCrossCorrelation, + split_k_slices + ); + + P = problem_size.P; + Q = problem_size.Q; + + typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ? + nullptr : + reinterpret_cast(C->data_ptr()); + + torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast); + at::Tensor D = torch::zeros({N, K, P, Q}, options); +""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x +) + + +_PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x = ( + common._CUTLASS_KERNEL_RUN_CONV2D_2x + + """ +at::Tensor ${name}_kernel(std::tuple input_size, const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, float alpha=1.f, float beta=0.f, + std::string split_k_mode="serial", int split_k_slices=1) { + int N, H, W, C_, K, R, S; + N = std::get<0>(input_size); + C_ = std::get<1>(input_size); + H = std::get<2>(input_size); + W = std::get<3>(input_size); + + K = B.size(0); + R = B.size(2); + S = B.size(3); + + cutlass::conv::Conv2dProblemSize problem_size( + cutlass::Tensor4DCoord(N, H, W, C_), + cutlass::Tensor4DCoord(K, R, S, C_), + cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)), + cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)), + cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)), + cutlass::conv::Mode::kCrossCorrelation, + split_k_slices + ); + + typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ? + nullptr : + reinterpret_cast(C->data_ptr()); + + torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast); + at::Tensor D = torch::empty({N, C_, H, W}, options); +""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x +) + + +_PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x = ( + common._CUTLASS_KERNEL_RUN_CONV2D_2x + + """ +at::Tensor ${name}_kernel(std::tuple weight_size, const at::Tensor& A, const at::Tensor& B, at::optional C=at::nullopt, + std::tuple stride={1, 1}, std::tuple padding={0, 0}, std::tuple dilation={1, 1}, float alpha=1.f, float beta=0.f, + std::string split_k_mode="serial", int split_k_slices=1) { + int N, H, W, C_, K, R, S; + K = std::get<0>(weight_size); + C_ = std::get<1>(weight_size); + R = std::get<2>(weight_size); + S = std::get<3>(weight_size); + + N = B.size(0); + H = B.size(2); + W = B.size(3); + + cutlass::conv::Conv2dProblemSize problem_size( + cutlass::Tensor4DCoord(N, H, W, C_), + cutlass::Tensor4DCoord(K, R, S, C_), + cutlass::Tensor4DCoord(std::get<0>(padding), std::get<0>(padding), std::get<1>(padding), std::get<1>(padding)), + cutlass::MatrixCoord(std::get<0>(stride), std::get<1>(stride)), + cutlass::MatrixCoord(std::get<0>(dilation), std::get<1>(dilation)), + cutlass::conv::Mode::kCrossCorrelation, + split_k_slices + ); + + typename UnderlyingKernel::ElementC* ptrC = (C == at::nullopt) ? + nullptr : + reinterpret_cast(C->data_ptr()); + + torch::TensorOptions options = torch::TensorOptions().dtype(${torch_type_C}).device(B.device()).memory_format(at::MemoryFormat::ChannelsLast); + at::Tensor D = torch::empty({K, C_, R, S}, options); +""" + _PYTORCH_CONV2D_IMPL_TEMPLATE_2x +) + _PYTORCH_SETUP_PY = common._PYSTYLE_AUTOGEN_COMMENT + """ from setuptools import setup @@ -607,6 +822,73 @@ def _pytorch_grouped_gemm( return None +def _pytorch_conv2d(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""): + """ + Generates source for building a PyTorch CUDA module that leverages the CUTLASS Conv2d + specified by ``op``. If the ``jit`` parameter is set to true, the module is just-in-time + compiled, loaded, and returned. + + :param op: operation to emit in the module + :param name: name of the module to generate + :type name: str + :param cc: compute capability of the device the module should target + :type cc: int + :param jit: whether the module should be just-in-time compiled + :type jit: bool + :param sourcedir: directory to which generated source files should be written + :type sourcedir: str + + Note that the when conv kind is `dgrad` or `wgrad`, the size of the input `(N, C, H, W)` or + weight `(K, C, R, S)` should be provided. This is because there are multiple valid solutions + for H/W/R/S given the same P/Q. + + :return: loaded PyTorch module if ``jit=True`` or ``None`` otherwise + """ + if sourcedir != "" and not os.path.isdir(sourcedir): + os.makedirs(sourcedir) + cuda_file = os.path.join(sourcedir, name + "_kernel.cu") + extra_kw = {} + if op.conv_kind == cutlass_bindings.conv.Operator.fprop: + impl_template = _PYTORCH_CONV2D_FPROP_IMPL_TEMPLATE_2x + cpp_template = _PYTORCH_CONV2D_FPROP_CPP_TEMPLATE + elif op.conv_kind == cutlass_bindings.conv.Operator.dgrad: + impl_template = _PYTORCH_CONV2D_DGRAD_IMPL_TEMPLATE_2x + cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE + elif op.conv_kind == cutlass_bindings.conv.Operator.wgrad: + impl_template = _PYTORCH_CONV2D_WGRAD_IMPL_TEMPLATE_2x + cpp_template = _PYTORCH_CONV2D_GRAD_CPP_TEMPLATE + extra_kw["conv_kind_name"] = ConvKindNames[op.conv_kind].capitalize() + extra_kw["torch_type_C"] = _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element] + cuda_impl = SubstituteTemplate(impl_template, {"name": name, **extra_kw}) + cuda_source = SubstituteTemplate( + _PYTORCH_CUDA_TEMPLATE, + { + "includes": _PYTORCH_CONV2D_INCLUDES, + "declaration": op.rt_module.emit(), + "procedural_name": op.procedural_name(), + "impl": cuda_impl, + "torch_type_C": _CUTLASS_TYPE_TO_TORCH_TYPE[op.C.element], + }, + ) + with open(cuda_file, "w") as outfile: + outfile.write(cuda_source) + + cpp_file = os.path.join(sourcedir, name + ".cpp") + cpp_source = SubstituteTemplate( + cpp_template, + {"name": name, "description": f"CUTLASS {op.procedural_name()} Conv2d"}, + ) + with open(cpp_file, "w") as outfile: + outfile.write(cpp_source) + + _generate_setup(name, sourcedir) + + if jit: + return _jit(name, cc, cpp_file, cuda_file) + + return None + + def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""): """ Generates source for building a PyTorch CUDA module that leverages the CUTLASS kernel @@ -633,6 +915,8 @@ def pytorch(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""): return _pytorch_gemm(device_op, name, cc, jit, sourcedir) elif isinstance(op, GemmOperationGrouped): return _pytorch_grouped_gemm(device_op, name, cc, jit, sourcedir) + elif isinstance(op, Conv2dOperation): + return _pytorch_conv2d(device_op, name, cc, jit, sourcedir) else: raise Exception( f"Operation type {type(op)} is not currently supported for PyTorch emission." diff --git a/python/cutlass/library_defaults.py b/python/cutlass/library_defaults.py index 997d5d3589..9e70d47489 100644 --- a/python/cutlass/library_defaults.py +++ b/python/cutlass/library_defaults.py @@ -43,6 +43,9 @@ # Imports from CUTLASS profiler generator and manifest scripts import generator as prof_generator import manifest as prof_manifest +from library import ( + ConvKind, IteratorAlgorithm, StrideSupport, GroupMode +) import cutlass from cutlass.utils.check import valid_stage_count @@ -132,6 +135,8 @@ def find_alignment(self, shape: tuple, layout: cutlass.LayoutType) -> int: ld = shape[0] elif layout == cutlass.LayoutType.RowMajor: ld = shape[1] + elif layout == cutlass.LayoutType.TensorNHWC: + ld = shape[-1] else: raise Exception(f"Unexpected or unsupported layout {layout}") @@ -222,8 +227,9 @@ def __init__( # find available opclasses and data types for name, op_list in manifest.operations[operation_kind].items(): for op in op_list: - if op.gemm_kind not in gemm_kinds: - continue + if operation_kind == cutlass.OperationKind.Gemm: + if op.gemm_kind not in gemm_kinds: + continue mi = op.tile_description.math_instruction if mi.math_operation not in self.allowed_math_operations: @@ -276,21 +282,36 @@ def __init__( if cutlass.OpcodeClass.Simt not in self.operations_by_opclass: self.operations_by_opclass[cutlass.OpcodeClass.Simt] = {} - types = [ - (cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s8), - (cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s32), - (cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16), - (cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32), - (cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32), - (cutlass.DataType.f64, cutlass.DataType.f64, cutlass.DataType.f64), - ] + if operation_kind == cutlass.OperationKind.Gemm: + types = [ + (cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s8), + (cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s32), + (cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16), + (cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32), + (cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32), + (cutlass.DataType.f64, cutlass.DataType.f64, cutlass.DataType.f64), + ] + + layouts = [ + (cutlass.LayoutType.RowMajor, cutlass.LayoutType.RowMajor), + (cutlass.LayoutType.RowMajor, cutlass.LayoutType.ColumnMajor), + (cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.RowMajor), + (cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.ColumnMajor), + ] + elif operation_kind == cutlass.OperationKind.Conv2d: + types = [ + (cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16), + (cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32), + (cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32), + (cutlass.DataType.f64, cutlass.DataType.f64, cutlass.DataType.f64), + ] + + layouts = [ + (cutlass.LayoutType.TensorNHWC, cutlass.LayoutType.TensorNHWC), + ] + else: + raise NotImplementedError(f"Operation kind {operation_kind} is currently unsupported.") - layouts = [ - (cutlass.LayoutType.RowMajor, cutlass.LayoutType.RowMajor), - (cutlass.LayoutType.RowMajor, cutlass.LayoutType.ColumnMajor), - (cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.RowMajor), - (cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.ColumnMajor), - ] alignment = 1 epilogue_functor = cutlass.EpilogueFunctor.LinearCombination swizzling_functor = cutlass.SwizzlingFunctor.Identity8 @@ -319,12 +340,22 @@ def __init__( if not valid_stage_count(target_cc, td_from_profiler_td(td))[0]: continue - new_operation = prof_manifest.GemmOperation( - cutlass.GemmKind.Universal, td.minimum_compute_capability, - td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor) - new_kernels = KernelsForDataType(type_comb, layout_comb) - new_kernels.add(new_operation) + + if operation_kind == cutlass.OperationKind.Gemm: + new_operation = prof_manifest.GemmOperation( + cutlass.GemmKind.Universal, td.minimum_compute_capability, + td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor) + new_kernels.add(new_operation) + elif operation_kind == cutlass.OperationKind.Conv2d: + for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]: + new_operation = prof_manifest.Conv2dOperation( + conv_kind, IteratorAlgorithm.Analytic, td.minimum_compute_capability, td, + A, B, C, type_comb[2], StrideSupport.Strided, epilogue_functor, swizzling_functor, + group_mode=GroupMode.SingleGroup + ) + new_kernels.add(new_operation) + self.operations_by_opclass[cutlass.OpcodeClass.Simt][comb] = new_kernels # Sort all operations @@ -437,9 +468,12 @@ def __init__(self, target_cc: int): self.registry = {} gemm_kinds = [cutlass.GemmKind.Universal, cutlass.GemmKind.Universal3x] + operation_kinds = [cutlass.OperationKind.Gemm, cutlass.OperationKind.Conv2d] # Construct options for each CC for kernel_cc in _generator_ccs: - self.registry[kernel_cc] = ArchOptions(target_cc, kernel_cc, cutlass.OperationKind.Gemm, gemm_kinds) + self.registry[kernel_cc] = {} + for opkind in operation_kinds: + self.registry[kernel_cc][opkind] = ArchOptions(target_cc, kernel_cc, opkind, gemm_kinds) - def options_for_cc(self, cc: int) -> ArchOptions: - return self.registry.get(cc, None) + def options_for_cc(self, cc: int, op_kind=cutlass.OperationKind.Gemm) -> ArchOptions: + return self.registry.get(cc, None)[op_kind] diff --git a/python/cutlass/op/__init__.py b/python/cutlass/op/__init__.py index 59b02a36d7..d3cfbe7e22 100644 --- a/python/cutlass/op/__init__.py +++ b/python/cutlass/op/__init__.py @@ -31,5 +31,6 @@ ################################################################################################# from cutlass.op.gemm import Gemm +from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad from cutlass.op.gemm_grouped import GroupedGemm from cutlass.op.op import OperationBase diff --git a/python/cutlass/op/conv.py b/python/cutlass/op/conv.py new file mode 100644 index 0000000000..32c5a8e778 --- /dev/null +++ b/python/cutlass/op/conv.py @@ -0,0 +1,960 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" + Ease-of-use interface for constructing, compiling, and running CONVs + + The ``Conv2d`` interface is meant to allow one to easily instantiate, compile, and run + CONV2D operations in CUTLASS via Python, without specifying many configuration parameters. + Under the hood, the interface will select sensible default parameters for the many template + parameters for CUTLASS CONVs. + + Note: optimal performance is not to be expected from this interface. To achieve optimal + performance, one should specify and tune each configuration parameter. + + The simplest example of using this interface is the following: + + .. highlight:: python + .. code-block:: python + + # A, B, C, and D are torch/numpy/cupy tensor objects + plan = cutlass.op.Conv(A, B, C, D) + plan.run(stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + One can also use the interface by specifying data types of operands at construction + and using different tensor objects with these data types at runtime: + + .. highlight:: python + .. code-block:: python + + # The following is shorthand for: + # cutlass.op.Conv2d(kind="fprop", + # element_A=torch.float32, element_B=torch.float32, + # element_C=torch.float32, element_D=torch.float32, + # element_accumulator=torch.float32) + plan = cutlass.op.Conv2d(kind="fprop", element=torch.float32) + + A0 = torch.rand((128, 256), dtype=torch.float32, device='cuda') + B0 = torch.rand((256, 64), dtype=torch.float32, device='cuda') + C0 = torch.zeros((128, 64), dtype=torch.float32, device='cuda') + D0 = torch.zeros((128, 64), dtype=torch.float32, device.'cuda') + plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + A = torch.rand((32, 128), dtype=torch.float32, device='cuda') + B = torch.rand((128, 256), dtype=torch.float32, device='cuda') + C = torch.zeros((32, 256), dtype=torch.float32, device='cuda') + D = torch.zeros((32, 256), dtype=torch.float32, device.'cuda') + plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + The interface additionally enables one to decouple the compilation of the underlying CUTLASS + kernel from its execution: + + .. highlight:: python + .. code-block:: python + + plan = cutlass.op.Conv2d(kind="fprop", element=np.float32) + + # Do other work... + + plan.run(A0, B0, C0, D0, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + # Do other work... + + plan.run(A1, B1, C1, D1, stride=(1, 1), padding=(0, 0), dilation=(1, 1)) + + Elementwise activation functions are easily fused to the GEMM via the interface: + + .. highlight:: python + .. code-block:: python + + plan = cutlass.op.Conv2d(kind="fprop", element=np.float32) + plan.activation = cutlass.epilogue.relu + + Operations can also be run asynchronously: + + .. highlight:: python + .. code-block:: python + + plan = cutlass.op.Conv2d(kind="fprop", element=np.float32) + args = plan.run() + + # Do other work... + + args.sync() +""" + +import cutlass_bindings +import cutlass +from cutlass import epilogue +from cutlass.backend import compiler +from cutlass.backend.conv2d_operation import Conv2dArguments, Conv2dOperation +from cutlass.backend.reduction_operation import ReductionOperation, ReductionArguments +from cutlass.backend.library import TensorDescription, TileDescription +from cutlass.op.op import OperationBase +from cutlass.utils import check, datatypes + +class Conv2d(OperationBase): + """ + Constructs a ``Conv2d`` object. + + The convolution kind (fprop, wgrad, degrad), the data types of operands A, B, and C, + along with the data type of output D and that used for accumulation, are bound to the ``Conv`` + object throughout its lifetime -- these are not to be changed after a ``Conv2d`` has been constructed. + + The constructor has optional parameters for flexibly setting these parameters. The following + constructors are equivalent: + + .. highlight:: python + .. code-block:: python + + # Use F32 for A, B, C, D, and accumulation in fprop + + # Use the generic ``element`` parameter to concisely set all data types for operands to the same values. + Conv2d(kind="fprop", element=cutlass.DataType.f32) + + # Explicitly specify the data types to use for A, B, C, and D. + Conv2d(kind="fprop", element_A=cutlass.DataType.f32, element_B=cutlass.DataType.f32, + element_C=cutlass.DataType.f32, element_D=cutlass.DataType.f32) + + # Set the data types and elements from existing tensors. Note that one can use different tensors when + # executing GEMM via the ``run()`` method than passed in here (though those passed in to ``run()`` must + # have the same data type as those passed in here). + # A, B, C, and D are torch.Tensor objects of type torch.float32 under the channel-last layout + Conv2d(kind="fprop", A=A, B=B, C=C, D=D) + + # Explicitly specify the data type for only some of A, B, C, and D. Unspecified data types will inherit + # those passed in via the generic ``element`` + Conv2d(kind="fprop", element_A=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32, + element=cutlass.DataType.f32) + + The order of precedence for the setting of the data type for a given operand/output is as follows: + 1) If the tensor type is specified (e.g., ``A``), use the data type inferred from this tensor + 2) Otherwise, if the data type (e.g., ``element_A``) is specified, use those + 3) Otherwise, use the generic values (e.g., ``element``) + + :param kind: the convolution kind (i.e. fprop, wgrad, and dgrad) + :type kind: str + :param A: tensor representing data type of operand A + :param B: tensor representing data type of operand B + :param C: tensor representing data type of operand C + :param D: tensor representing data type of operand D + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type + :type element: cutlass.DataType + :param element_A: data type to be used for operand A + :type element_A: cutlass.DataType + :param element_B: data type to be used for operand B + :type element_B: cutlass.DataType + :param element_C: data type to be used for operand C + :type element_C: cutlass.DataType + :param element_D: data type to be used for operand D + :type element_D: cutlass.DataType + :param element_accumulator: data type to be used in accumulation of the product of operands A and B + :type element_accumulator: cutlass.DataType + :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90 + :type cc: int + :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80 + :type kernel_cc: int + """ + def __init__( + self, kind="fprop", + A=None, B=None, C=None, D=None, alpha=1.0, beta=0.0, + element=None, + element_A=None, element_B=None, element_C=None, element_D=None, + element_accumulator=None, + cc: int = None, kernel_cc: int = None + ): + super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=cutlass.OperationKind.Conv2d) + # Verify the kernel cc + if self.current_cc == 90: + # The Conv2d kernel on Hopper (SM90) is currently unsupported + # Revert to use SM80-tagged kernels + cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") + self.specified_kernel_cc = 80 + self._reset_options(80) + + # The arch is used in testing + self.arch = self.current_cc + self.name = "conv2d" + kind + + # The convolution kind. (concept: cutlass_bindings.conv.Operator) + self.conv_kind = getattr(cutlass_bindings.conv.Operator, kind) + + # The element types (concept: cutlass library types) of A, B, C, and D + elements = [] + layouts = [] + + # Complete the data types based on user-provided arguments + for elt, tens, name in zip([element_A, element_B, element_C, element_D], + [A, B, C, D], + ["A", "B", "C", "D"]): + if elt is not None and tens is not None: + raise Exception(f'Must not specify both element_{name} and tensor {name}') + if elt is None and tens is None and element is None: + raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.') + + elt_to_set = None + lay_to_set = None + + if tens is not None: + elt_to_set, _ = datatypes.get_datatype_and_layout(tens) + else: + elt_to_set = elt if elt is not None else element + + assert elt_to_set is not None + + # Currently we only support layout TensorNHWC + lay_to_set = cutlass.LayoutType.TensorNHWC + elements.append(datatypes.library_type(elt_to_set)) + layouts.append(lay_to_set) + + self._element_a, self._element_b, self._element_c, self._element_d = elements + self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts + + self.A, self.B, self.C, self.D, self.alpha, self.beta = A, B, C, D, alpha, beta + + if element_accumulator is None: + self._element_accumulator = self._element_c + else: + self._element_accumulator = datatypes.library_type(element_accumulator) + + # Default inputs if none is supplied in run() + self.A = A + self.B = B + self.C = C + self.D = D + + self.alpha = alpha + self.beta = beta + + # We only specify the stride of the swizzling functor here + # The actual swizzling functor is determined in run based on conv_kind and stride + self._swizzling_stride = 1 + + # Arguments that will be set to default value in _reset_operations + # The default tile_description and op_class are fetched from manifest of cutlass library + self._tile_description = None + self.op_class = None + # The default identity epilogue will be created + self.epilogue_functor = None + + self._reset_operations() + + # Arguments that will be determined online based on arguments of "run" + # based on stride, input/output channels, alignment, and conv_kind + self._iterator_algorithm = None + self._stride_support = None + + def _reset_operations(self, reset_epilogue: bool = True): + # Set the default op class + datatype_comb = (self._element_a, self._element_b, self._element_accumulator) + layout_comb = (self._layout_a, self._layout_b) + + self.possible_op_classes = self.options.supporting_opclasses( + self._element_a, self._element_b, self._element_accumulator, + self._layout_a, self._layout_b + ) + + if cutlass.OpcodeClass.TensorOp in self.possible_op_classes: + self.opclass = cutlass.OpcodeClass.TensorOp + elif cutlass.OpcodeClass.Simt in self.possible_op_classes: + self.opclass = cutlass.OpcodeClass.Simt + else: + raise Exception(f'No kernel configuration found for supported data type and layout ' + f'combination {datatype_comb}x{layout_comb}') + + if reset_epilogue: + self._reset_epilogue_functor_activation(epilogue.identity) + + self.alignment_pref_A = min( + 128 // cutlass.DataTypeSize[self._element_a], max(self.possible_operations.alignments)) + self.alignment_pref_B = min( + 128 // cutlass.DataTypeSize[self._element_b], max(self.possible_operations.alignments)) + self.alignment_pref_C = min( + 128 // cutlass.DataTypeSize[self._element_c], max(self.possible_operations.alignments)) + + # + # Tile description Related + # + + @property + def tile_description(self) -> TileDescription: + """ + Returns the tile description + """ + return self._tile_description + + @tile_description.setter + def tile_description( + self, td=None): + """ + Set the tile description + + :param td: tile description + :type td: cutlass.backend.TileDescription, or a dict with keys + { + "threadblock_shape": [int, int, int], + "warp_count": [int, int, int], + "stages": int, + "instruction_shape": [int, int, int] (optional), + "cluster_shape": [int, int, int] (optional) + } + """ + if td is None: + return + if isinstance(td, dict): + if self._tile_description is None: + alignment = list(self.possible_operations.kernels_by_alignment.keys())[0] + op = self.possible_operations.operations(alignment)[0] + self._tile_description = datatypes.td_from_profiler_op(op) + if "cluster_shape" in td.keys(): + if td["cluster_shape"] != [1, 1, 1]: + cutlass.logger.warning("Conv2d currently only support 'cluster_shape'=[1, 1, 1]'.") + td["cluster_shape"] = [1, 1, 1] + td = self._tile_description.clone_and_update(td) + + valid, msg = self._valid_tile_description(td) + if valid: + self._tile_description = td + else: + raise Exception(msg) + + def _valid_tile_description(self, td: TileDescription) -> tuple: + """ + Checks whether the provided tile description is valid for the given compute capability. At present, + this checks the following: + + - Does the tile description use a number of stages supported by the compute capability in question? + - Does the tile size requested fit within shared memory? + - Are cluster dimensions outside the valid range requested for a given architecture (e.g., + more non-unit cluster dimensions for pre-SM90 architectures)? + - Is the kernel schedule being used supported on the architecture in question? + + :param td: tile description to validate + :type td: cutlass.backend.TileDescription + :return: tuple in which the first element is a bool indicating that the tile description is valid + and the second element is a string providing an optional error message. + :rtype: tuple + """ + # Check stage count based on the CC to which we are compiling (self.cc), rather + # than the CC from which we find kernels (self.current_cc) + valid, msg = check.valid_stage_count(self.cc, td) + if not valid: + return (valid, msg) + + valid, msg = check.valid_cluster_shape(self.current_cc, td.cluster_shape) + if not valid: + return (valid, msg) + + return valid, msg + + def tile_descriptions(self) -> list: + """ + Returns a list of valid tile descriptions for the operations + + :returns: list of valid tile descriptions for the operations + :rtype: list + """ + descriptions = [] + description_str = [] + for op in self.possible_operations.all_operations: + td = datatypes.td_from_profiler_op(op) + if str(td) not in description_str: + description_str.append(str(td)) + descriptions.append(td) + return descriptions + + # + # Swizzling functor Related + # + + @property + def swizzling_stride(self): + """ + Returns the stride of swizzling currently being used by the Conv2d + + :return: swizzing stride + """ + return self._swizzling_stride + + @swizzling_stride.setter + def swizzling_stride(self, stride: int): + """ + Sets the swizzling functor to the type specified by `swizzling_functor` + """ + if not isinstance(stride, int): + raise Exception(f"Expect integer (1, 2, 4, 8), got {stride}") + self._swizzling_stride = stride + + def _propose_swizzling_functor(self, stride): + """ + Automatically propose the swizzling functor based on the stride + """ + if self.conv_kind == cutlass_bindings.conv.Operator.dgrad: + if stride[0] != 1 or stride[1] != 1: + return getattr(cutlass.swizzle, f"StridedDgradIdentitySwizzle{self._swizzling_stride}") + + return getattr(cutlass.swizzle, f"IdentitySwizzle{self._swizzling_stride}") + + # + # Iterator Algorithm Related + # + + @property + def iterator_algorithm(self) -> cutlass_bindings.conv.IteratorAlgorithm: + """ + Returns the iterator algorithm + """ + return self._iterator_algorithm + + @iterator_algorithm.setter + def iterator_algorithm(self, alg: str): + """ + Sets the iterator algorithm + + :param alg: The iterator algorithm + :type td: string, options: "analytic", "optimized", "few_channels", and "fixed_channels" + """ + # Check if the iterator algorithm is valid + if alg in ["few_channels", "fixed_channels"] and self.conv_kind != cutlass_bindings.conv.Operator.fprop: + raise Exception(f"{self.conv_kind} does not support iterator algorithm {alg}.") + + self._iterator_algorithm = getattr(cutlass_bindings.conv.IteratorAlgorithm, alg) + + def _propose_iterator_algorithm(self, problem_size, alignment_a, alignment_b) -> cutlass_bindings.conv.IteratorAlgorithm: + """ + Propose a valid iterator algorithm based on problem size and alignment + """ + if self.conv_kind == cutlass_bindings.conv.Operator.fprop: + # Check whether the fixed channel is applicable + if problem_size.C == alignment_a: + return cutlass_bindings.conv.IteratorAlgorithm.fixed_channels + elif (problem_size.C % alignment_a == 0 and + problem_size.R <= 32 and problem_size.S <= 32): + return cutlass_bindings.conv.IteratorAlgorithm.optimized + else: + return cutlass_bindings.conv.IteratorAlgorithm.analytic + elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad: + if (problem_size.K % alignment_a == 0 and + problem_size.R <= 32 and problem_size.S <= 32 and + problem_size.C % alignment_b == 0): + return cutlass_bindings.conv.IteratorAlgorithm.optimized + else: + return cutlass_bindings.conv.IteratorAlgorithm.analytic + elif self.conv_kind == cutlass_bindings.conv.Operator.wgrad: + if (problem_size.K % alignment_a == 0 and + problem_size.C % alignment_b == 0): + return cutlass_bindings.conv.IteratorAlgorithm.optimized + else: + return cutlass_bindings.conv.IteratorAlgorithm.analytic + + def _validate_iterator_algorithm(self, iterator_algorithm, problem_size, alignment_a, alignment_b) -> bool: + """ + Validate whether the user provide iterator algorithm works for the given problem size + """ + if self.conv_kind == cutlass_bindings.conv.Operator.fprop: + if iterator_algorithm == cutlass_bindings.conv.IteratorAlgorithm.fixed_channels: + return problem_size.C == alignment_a + elif iterator_algorithm == cutlass_bindings.conv.IteratorAlgorithm.optimized: + return (problem_size.C % alignment_a == 0 and + problem_size.R <= 32 and problem_size.S <= 32) + elif iterator_algorithm == cutlass_bindings.conv.IteratorAlgorithm.few_channels: + return problem_size.C % alignment_a == 0 + elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad: + if iterator_algorithm == cutlass_bindings.conv.IteratorAlgorithm.optimized: + return (problem_size.K % alignment_a == 0 and + problem_size.R <= 32 and problem_size.S <= 32 and + problem_size.C % alignment_b == 0) + elif self.conv_kind == cutlass_bindings.conv.Operator.wgrad: + if iterator_algorithm == cutlass_bindings.conv.IteratorAlgorithm.optimized: + return (problem_size.K % alignment_a == 0 and + problem_size.C % alignment_b == 0) + + return True + + # + # Stride Support Related + # + + def _propose_stride_support(self, stride): + if self.conv_kind == cutlass_bindings.conv.Operator.dgrad: + if stride[0] == 1 and stride[1] == 1: + return cutlass.backend.library.StrideSupport.Unity + + return cutlass.backend.library.StrideSupport.Strided + + # + # Construct and Compilation + # + + def construct( + self, tile_description: TileDescription = None, + alignment_A: int = None, alignment_B: int = None, alignment_C: int = None, + iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm = None, + stride_support = None, swizzling_functor: cutlass.swizzle = None, + epilogue_functor=None) -> cutlass.backend.Conv2dOperation: + """ + Constructs a ``cutlass.backend.Conv2dOperation`` based on the input parameters and current + kernel specification of the ``Conv2d`` object. + + :param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + :param iterator_algorithm: the iterator algorithm used + :type iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm + :param stride_support: the stride support of dgrad + :type stride_support: cutlass.backend.library.StrideSupport + :param swizzling_functor: the swizzling functor + :type swizzling_functor: cutlass.swizzle + :param epilogue_functor: the epilogue functor + + :return: operation that was constructed + :rtype: cutlass.backend.Conv2dOperation + """ + # Get alignment + alignment_A = check.alignment_or_default(alignment_A, self.alignment_pref_A) + alignment_B = check.alignment_or_default(alignment_B, self.alignment_pref_B) + alignment_C = check.alignment_or_default(alignment_C, self.alignment_pref_C) + + tensor_A = TensorDescription( + datatypes.binding_type(self._element_a), + datatypes.binding_layout(self._layout_b), + alignment_A + ) + tensor_B = TensorDescription( + datatypes.binding_type(self._element_b), + datatypes.binding_layout(self._layout_b), + alignment_B + ) + tensor_C = TensorDescription( + datatypes.binding_type(self._element_c), + datatypes.binding_layout(self._layout_c), + alignment_C + ) + + if tile_description is None: + if self.tile_description is not None: + tile_description = self.tile_description + else: + op = self.possible_operations.operations(alignment_A)[0] + tile_description = datatypes.td_from_profiler_op(op) + else: + valid, err_str = self._valid_tile_description(tile_description) + if not valid: + raise Exception(f"Invalid tile description. {err_str}") + self.tile_description = tile_description + + if iterator_algorithm is None: + # If the iterator algorithm is already set + if self.iterator_algorithm is not None: + iterator_algorithm = self.iterator_algorithm + else: + # Otherwise, we conservatively use the analytic iterator for correctness + iterator_algorithm = cutlass_bindings.conv.IteratorAlgorithm.analytic + + if stride_support is None: + # If the stride support is already set + if self._stride_support is not None: + stride_support = self._stride_support + else: + # Otherwise, we assume strided + stride_support = cutlass.backend.library.StrideSupport.Strided + + if swizzling_functor is None: + # If the swizzling functor is already set + swizzling_functor = self._propose_swizzling_functor(stride=(2, 2)) + + if epilogue_functor is None: + if self.epilogue_functor is not None: + epilogue_functor = self.epilogue_functor + else: + epilogue_functor = self._create_epilogue_functor_activation(self._activation) + + # Reset the alignment of the epilogue functor + epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, epilogue_functor) + + operation = Conv2dOperation( + conv_kind=self.conv_kind, + iterator_algorithm=iterator_algorithm, + arch=self.current_cc, + tile_description=tile_description, + A=tensor_A, B=tensor_B, C=tensor_C, + stride_support=stride_support, + epilogue_functor=epilogue_functor, + swizzling_functor=swizzling_functor, + ) + + return operation + + def compile(self, tile_description: TileDescription = None, + alignment_A: int = None, alignment_B: int = None, alignment_C: int = None, + iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm = None, + stride_support = None, swizzling_functor: cutlass.swizzle = None, + epilogue_functor = None, print_module: bool = False) -> cutlass.backend.Conv2dOperation: + """ + Emits and compiles the kernel currently specified. If ``tile_description`` and any + of the ``alignment`` parameters are set, the kernel will be chosen using this + tile description and alignments. Otherwise, a default tile description and alignment + will be used. + + ::param tile_description: tile description specifying shapes and operand types to use in the kernel + :type tile_description: cutlass.backend.TileDescription + :param alignment_A: alignment of operand A + :type alignment_A: int + :param alignment_B: alignment of operand B + :type alignment_B: int + :param alignment_C: alignment of operand C + :type alignment_C: int + :param iterator_algorithm: the iterator algorithm used + :type iterator_algorithm: cutlass_bindings.conv.IteratorAlgorithm + :param stride_support: the stride support of dgrad + :type stride_support: cutlass.backend.library.StrideSupport + :param swizzling_functor: the swizzling functor + :type swizzling_functor: cutlass.swizzle + :param epilogue_functor: the epilogue functor + + :return: operation that was compiled + :rtype: cutlass.backend.Conv2dOperation + """ + + self.operation = self.construct( + tile_description, alignment_A, alignment_B, alignment_C, + iterator_algorithm, stride_support, swizzling_functor, epilogue_functor) + + if print_module: + print(self.operation.rt_module.emit()) + + compiler.add_module([self.operation,]) + return self.operation + + # + # Run Related + # + + def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name): + """ + Verifies that ``tensor`` has data type ``ref_type`` and layout ``ref_layout``. An exception + is raised if it does not. + + :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type tensor: numpy/cupy/torch array/tensor object + :param ref_dtype: data type for the tensor that this object was initialized to + :param name: identifier of the tensor to verify. Used in raising exceptions + :type name: str + """ + dtype, _ = datatypes.get_datatype_and_layout(tensor) + if dtype != ref_type: + raise Exception(f'Tensor {name} with type and layout {dtype} ' + f'does not match the expected type of {ref_type}.') + + + + def _get_and_verify_conv_problem_size(self, A, B, C, stride, padding, dilation): + if self.conv_kind == cutlass_bindings.conv.Operator.fprop: + input = A + weight = B + output = C + output_tensor = "C" + elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad: + output = A + weight = B + input = C + output_tensor = "A" + elif self.conv_kind == cutlass_bindings.conv.Operator.wgrad: + output = A + input = B + weight = C + output_tensor = "A" + else: + raise Exception(f"Convolution kind {self.conv_kind} is not supported") + + N_, H_, W_, C_ = datatypes.get_tensor_shape(input) + K_, R_, S_, _ = datatypes.get_tensor_shape(weight) + _, P_, Q_, _ = datatypes.get_tensor_shape(output) + + problem_size = cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(N_, H_, W_, C_), + cutlass_bindings.Tensor4DCoord(K_, R_, S_, C_), + cutlass_bindings.Tensor4DCoord(padding[0], padding[0], padding[1], padding[1]), + cutlass_bindings.MatrixCoord(stride[0], stride[1]), + cutlass_bindings.MatrixCoord(dilation[0], dilation[1]), + cutlass_bindings.conv.Mode.cross_correlation, + 1, 1 + ) + + if P_ != problem_size.P or Q_ != problem_size.Q: + raise Exception( + f"Tensor {output_tensor} size should be ({N_}, {problem_size.P}, {problem_size.Q}, {K_}), got ({N_}, {P_}, {Q_}, {K_})") + + return problem_size + + def run(self, A=None, B=None, C=None, D=None, + stride=(1, 1), padding=(0, 0), dilation=(1, 1), + alpha=None, beta=None, + split_k=("serial", 1), sync: bool = True, + print_module: bool = False) -> Conv2dArguments: + """ + Runs the kernel currently specified. If it has not already been, the kernel is emitted and + compiled. Tensors holding operands and outputs of the kernel are sourced either from the + ``A``, ``B``, ``C``, ``D``, ``alpha``, and ``beta`` + parameters provided in the call, or from those + passed in on the construction of this object -- one of the two must be specified. + + By default, this call returns only once the kernel has completed. To launch the kernel + and immediately return, set ``sync=False``. In this case, it is the responsibility of the + caller to syncrhonize the results of the kernel before attempting to access outputs + by calling ``sync()`` on the arguments returned from this call. + + :param A: tensor representing data type and layout of operand A + :param B: tensor representing data type and layout of operand B + :param C: tensor representing data type and layout of operand C + :param D: tensor representing data type and layout of operand D + :param stride: (stride_h, stride_w) describing the convolution stride. Default: (1, 1) + :param padding: (pad_h, pad_w) describing the convolution padding. Default: (0, 0) + :param dilation: (dilation_h, dilation_w) describing the dilation of convolution. Default: (1, 1) + :param alpha: scalar paramter alpha from GEMM computation that scales the product of operands A and B + :param beta: scalar parameter beta from GEMM operation that scales operand C + :param split_k: a tuple (split_k_mode, split_k_slices) + :param sync: whether the call should wait for the kernel to complete before returning + :type sync: bool + :param print_module: whether to print the emitted C++ code + :type print_module: bool + + :return: arguments passed in to the kernel + :rtype: cutlass.backend.Conv2dArguments + """ + A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A") + B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B") + C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C") + D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D") + alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha") + beta = self._verify_scalar(beta, self.beta, self._element_c, "beta") + + # handle the case when there is no C + if C is None: + if beta != 0: + raise Exception(f"With beta {beta} != 0, C has to be provided.") + else: + C = D + + # Construct problem size based on input + # It also verifies whether the A, B, C, D, stride, padding, and dilation are matching + problem_size = self._get_and_verify_conv_problem_size(A, B, C, stride, padding, dilation) + + # Propose stride support based on input + stride_support = self._propose_stride_support(stride) + + # Propose swizzling functor + swizzling_functor = self._propose_swizzling_functor(stride) + + # Get the alignment + alignment_a = self.possible_operations.find_alignment(datatypes.get_tensor_shape(A), self._layout_a) + alignment_b = self.possible_operations.find_alignment(datatypes.get_tensor_shape(B), self._layout_b) + alignment_c = self.possible_operations.find_alignment(datatypes.get_tensor_shape(C), self._layout_c) + + alignment_a = check.update_alignment(alignment_a, self.alignment_pref_A) + alignment_b = check.update_alignment(alignment_b, self.alignment_pref_B) + alignment_c = check.update_alignment(alignment_c, self.alignment_pref_C) + + # Propose iterator algorithm based on input + if self._iterator_algorithm is None: + # Propose a default itertaor algorithm based on the problem size + iterator_algorithm = self._propose_iterator_algorithm(problem_size, alignment_a, alignment_b) + else: + if (self._validate_iterator_algorithm(self._iterator_algorithm, problem_size, alignment_a, alignment_b)): + iterator_algorithm = self._iterator_algorithm + else: + raise Exception(f"Iterator algorithm {self._iterator_algorithm} is invalid for current problem.") + + epilogue_args = [alpha, beta] + + if hasattr(self, "_activation_args"): + if isinstance(self._activation_args, list): + epilogue_args += self._activation_args + else: + epilogue_args.append(self._activation_args) + + if split_k[0] == "parallel" and split_k[1] > 1: + epilogue_functor = self._create_epilogue_functor_activation(epilogue.identity) + else: + epilogue_functor = self.epilogue_functor + + # The alignment is determined by the iterator function (I believe) + self.compile(tile_description=self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b, + alignment_C=alignment_c, iterator_algorithm=iterator_algorithm, stride_support=stride_support, + swizzling_functor=swizzling_functor, epilogue_functor=epilogue_functor, print_module=print_module) + + # Create reduction operation for parallel split-k + if split_k[0] == "parallel" and split_k[1] > 1: + epilogue_functor_reduction = self._reset_epilogue_functor_alignment(alignment_c, self.epilogue_functor) + self.reduction_operation = ReductionOperation( + shape=cutlass_bindings.MatrixCoord(4, 32 * alignment_c), C=self.operation.C, + element_accumulator=datatypes.binding_type(self._element_accumulator), + element_compute=datatypes.binding_type(self._element_accumulator), + epilogue_functor=epilogue_functor_reduction, + count=alignment_c + ) + if print_module: + print(self.reduction_operation.rt_module.emit()) + compiler.add_module([self.reduction_operation,]) + + arguments = Conv2dArguments( + operation=self.operation, problem_size=problem_size, + A=A, B=B, C=C, D=D, + output_op=self.operation.epilogue_type(*epilogue_args), + split_k_mode=datatypes.getattr_enum(cutlass_bindings.conv.SplitKMode, split_k[0]), + split_k_slices=split_k[1] + ) + + self.operation.run(arguments) + + if split_k[0] == "parallel" and split_k[1] > 1: + implicit_gemm_size = cutlass_bindings.conv.implicit_gemm_problem_size( + self.conv_kind, arguments.problem_size + ) + reduction_arguments = ReductionArguments( + self.reduction_operation, + problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()], + partitions=split_k[1], + workspace=arguments.ptr_D, + destination=D, + source=C, + output_op=self.reduction_operation.epilogue_type(*epilogue_args) + ) + self.reduction_operation.run(reduction_arguments) + + if sync: + if split_k[0] == "parallel" and split_k[1] > 1: + reduction_arguments.sync() + else: + arguments.sync() + + return arguments + + # + # Helper functions + # + @staticmethod + def output_size(input_size, weight_size, padding, stride, dilation): + problem_size = cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(*input_size), + cutlass_bindings.Tensor4DCoord(*weight_size), + cutlass_bindings.Tensor4DCoord(padding[0], padding[0], padding[1], padding[1]), + cutlass_bindings.MatrixCoord(stride[0], stride[1]), + cutlass_bindings.MatrixCoord(dilation[0], dilation[1]), + cutlass_bindings.conv.Mode.cross_correlation, + 1, 1 + ) + return (problem_size.N, problem_size.P, problem_size.Q, problem_size.K) + + +# +# Easy to use interfaces for fprop, wgrad, and dgrad +# + +class Conv2dFprop(Conv2d): + def __init__( + self, + input=None, weight=None, C=None, output=None, alpha=1, beta=0, + element=None, + element_input=None, element_weight=None, element_C=None, element_output=None, + element_accumulator=None, + cc: int = None, kernel_cc: int = None): + A, B, D = input, weight, output + element_A, element_B, element_D = element_input, element_weight, element_output + super().__init__( + "fprop", A, B, C, D, alpha, beta, element, + element_A, element_B, element_C, element_D, + element_accumulator, cc, kernel_cc) + + def run( + self, input=None, weight=None, C=None, output=None, alpha=None, beta=None, + stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1), + sync: bool = True, print_module: bool = False) -> Conv2dArguments: + + A, B, D = input, weight, output + return super().run( + A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module) + + +class Conv2dDgrad(Conv2d): + def __init__( + self, + grad_output=None, weight=None, C=None, grad_input=None, alpha=1, beta=0, + element=None, + element_grad_output=None, element_weight=None, element_C=None, element_grad_input=None, + element_accumulator=None, + cc: int = None, kernel_cc: int = None): + A, B, D = grad_output, weight, grad_input + element_A, element_B, element_D = element_grad_output, element_weight, element_grad_input + super().__init__( + "dgrad", A, B, C, D, alpha, beta, element, + element_A, element_B, element_C, element_D, + element_accumulator, cc, kernel_cc) + + def run(self, grad_output=None, weight=None, C=None, grad_input=None, alpha=None, beta=None, + stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1), + sync: bool = True, print_module: bool = False) -> Conv2dArguments: + # + A, B, D = grad_output, weight, grad_input + return super().run( + A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module) + + +class Conv2dWgrad(Conv2d): + def __init__( + self, + grad_output=None, input=None, C=None, grad_weight=None, alpha=1, beta=0, + element=None, + element_grad_output=None, element_input=None, element_C=None, element_grad_weight=None, + element_accumulator=None, + cc: int = None, kernel_cc: int = None): + A, B, D = grad_output, input, grad_weight + element_A, element_B, element_D = element_grad_output, element_input, element_grad_weight + super().__init__( + "wgrad", A, B, C, D, alpha, beta, element, + element_A, element_B, element_C, element_D, + element_accumulator, cc, kernel_cc) + + def run(self, grad_output=None, input=None, C=None, grad_weight=None, alpha=None, beta=None, + stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1), + sync: bool = True, print_module: bool = False) -> Conv2dArguments: + # + A, B, D = grad_output, input, grad_weight + return super().run( + A, B, C, D, alpha, beta, stride, padding, dilation, split_k, sync, print_module) diff --git a/python/cutlass/op/gemm.py b/python/cutlass/op/gemm.py index e33843ae52..67d1f14ede 100644 --- a/python/cutlass/op/gemm.py +++ b/python/cutlass/op/gemm.py @@ -287,108 +287,6 @@ def _reset_operations(self, reset_epilogue: bool = True): if reset_epilogue: self._reset_epilogue_functor_activation(epilogue.identity) - def _reset_epilogue_functor_activation(self, activation): - if self.epilogue_functor is None: - if self.op_class == cutlass.OpcodeClass.Simt: - elements_per_access = 1 - else: - elements_per_access = 128 // cutlass.DataTypeSize[self._element_c] - else: - elements_per_access = self.epilogue_functor.epilogue_vector_length - - if not self.specified_kernel_cc: - if self.current_cc == 90 and activation != epilogue.identity: - # CUTLASS 3.0 kernels currently only support identity activation. If one requests a non-identity activation, - # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. - cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") - self._reset_options(80) - self._reset_operations(reset_epilogue=False) - elif (self.cc == 90 and self.current_cc != 90 and activation == epilogue.identity): - # SM80 fallback kernels are currently used. Since an identity activation is requested, - # we can switch back to using SM90 kernels. - self._reset_options(90) - self._reset_operations(reset_epilogue=False) - else: - if self.current_cc == 90 and activation != epilogue.identity: - raise Exception("Epilogues with elementwise fusion are not currently supported " - "in the Python interface for 3.x kernels. To use 2.x kernels " - "with fused elementwise epilogues, do not set the `kernel_cc` " - "parameter when constructing the Gemm object.") - - self.epilogue_functor = epilogue.get_activation_epilogue( - activation, - datatypes.binding_type(self._element_c), - elements_per_access, - datatypes.binding_type(self._element_accumulator), - datatypes.binding_type(self._element_accumulator), - ) - - def _reset_epilogue_functor_alignment(self, alignment): - if self.epilogue_functor is None or not hasattr(self.epilogue_functor, 'activation_functor'): - activation = epilogue.identity - else: - activation = type(self.epilogue_functor.activation_functor) - - self.epilogue_functor = epilogue.get_activation_epilogue( - activation, - datatypes.binding_type(self._element_c), - alignment, - datatypes.binding_type(self._element_accumulator), - datatypes.binding_type(self._element_accumulator), - ) - - @property - def activation(self): - """ - Returns the type of the current activation function used - """ - return type(self.epilogue_functor.activation_functor) - - @activation.setter - def activation(self, act): - """ - Sets the type of the activation function to use - """ - self._reset_epilogue_functor_activation(act) - - @property - def opclass(self) -> cutlass.OpcodeClass: - """ - Returns the opcode class currently in use by the GEMM - - :return: opcode class currently in use - :rtype: cutlass.OpcodeClass - """ - return self.op_class - - @opclass.setter - def opclass(self, oc: cutlass.OpcodeClass): - """ - Sets the opcode class to use in the GEMM. If the opcode class is not supported under - the given compute capability and element/layout combinations of the GEMM, an exception is raised. - """ - if oc in self.possible_op_classes: - self.op_class = oc - else: - raise Exception( - f'Unsupported operation class {oc} for CC {self.cc} and data type combination ' - f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and ' - f'layout combination ({self._layout_a}, {self._layout_b}).') - - # Changing the op class changes the elements per access in the epilogue. Reset this. - if self.op_class == cutlass.OpcodeClass.Simt: - elements_per_access = 1 - else: - elements_per_access = 128 // cutlass.DataTypeSize[self._element_c] - - if self.epilogue_functor is not None: - self._reset_epilogue_functor_alignment(elements_per_access) - - # Changing the op class also changes the possible operations available. Reset these. - self.possible_operations = self.options.operations( - self.op_class, self._element_a, self._element_b, - self._element_accumulator, self._layout_a, self._layout_b) - @property def swizzling_functor(self): """ @@ -430,7 +328,7 @@ def _valid_tile_description(self, td: TileDescription) -> tuple: """ # Check stage count based on the CC to which we are compiling (self.cc), rather # than the CC from which we find kernels (self.current_cc) - valid, msg = check.valid_stage_count(self.cc, td) + valid, msg = check.valid_stage_count(self.cc, td, self._element_c, self._element_d) if not valid: return (valid, msg) @@ -438,7 +336,7 @@ def _valid_tile_description(self, td: TileDescription) -> tuple: if not valid: return (valid, msg) - valid, msg = check.valid_kernel_schedule(self.current_cc, td.kernel_schedule) + valid, msg = check.valid_schedule(self.current_cc, td.kernel_schedule, td.epilogue_schedule, td.tile_scheduler) return valid, msg def tile_descriptions(self) -> list: @@ -476,7 +374,7 @@ def construct( alignment_B = check.alignment_or_default(alignment_B, alignment_pref_B) alignment_C = check.alignment_or_default(alignment_C, alignment_pref_C) - self._reset_epilogue_functor_alignment(alignment_C) + self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor) tensor_A = TensorDescription( datatypes.binding_type(self._element_a), @@ -562,68 +460,6 @@ def _verify_type_and_layout(self, tensor, ref_type, ref_layout, name): f'does not match the expected type and ' f'layout of ({ref_type}, {ref_layout}).') - def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name): - """ - Verifies the following properties: - 1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``) - 2) If ``tensor`` is not ``None``, its datatype and layout must match matches the current versions - set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``) - - If either of these properties does not hold, an exception is raised. If these properties hold and - ``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned. - - :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in - :type tensor: numpy/cupy/torch array/tensor object - :param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in - :type ref_tensor: numpy/cupy/torch array/tensor object - :param ref_dtype: data type for the tensor that this object was initialized to - :param ref_layout: layout for the tensor that this object was initialized to - :param name: identifier of the tensor to verify. Used in raising exceptions - :type name: str - - :return: valid tensor object to use - :rtype: numpy/cupy/torch array/tensor object - """ - if tensor is None: - if ref_tensor is None: - raise Exception(f"Tensor {name} must be set.") - return ref_tensor - - self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name) - return tensor - - def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name): - """ - Verifies the following properties: - 1) Either ``scalar`` or ``ref_scakar`` must be set (i.e., not ``None``) - 2) If ``scalar`` is not ``None``, its datatype must match matches the current version - set by the plan (i.e., those in ``ref_dtype``) - - If either of these properties does not hold, an exception is raised. If these properties hold and - ``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned. - - :param scalar: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in - :type scalar: numpy/cupy/torch scalar - :param ref_scalar: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in - :type ref_scalar: numpy/cupy/torch scalar - :param ref_dtype: data type for the scalar that this object was initialized to - :param name: identifier of the scalar to verify. Used in raising exceptions - :type name: str - - :return: valid scalar to use - :rtype: numpy/cupy/torch scalar - """ - if scalar is None: - if ref_scalar is None: - raise Exception(f"Scalar {name} must be set.") - return ref_scalar - dtype = datatypes.library_type(scalar.dtype) - if dtype != ref_dtype: - raise Exception( - f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}." - ) - return scalar - def run(self, A=None, B=None, C=None, D=None, alpha=None, beta=None, batch_count: int = 1, sync: bool = True, print_module: bool = False) -> GemmArguments: diff --git a/python/cutlass/op/gemm_grouped.py b/python/cutlass/op/gemm_grouped.py index b8261fc168..d7eeb53f43 100644 --- a/python/cutlass/op/gemm_grouped.py +++ b/python/cutlass/op/gemm_grouped.py @@ -168,7 +168,7 @@ def construct(self, tile_description: TileDescription = None, alignment_B = check.alignment_or_default(alignment_B, alignment_preference) alignment_C = check.alignment_or_default(alignment_C, alignment_preference) - self._reset_epilogue_functor_alignment(alignment_C) + self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor) tensor_A = TensorDescription( datatypes.binding_type(self._element_a), diff --git a/python/cutlass/op/op.py b/python/cutlass/op/op.py index cb76b3edf3..be0fb2ae9b 100644 --- a/python/cutlass/op/op.py +++ b/python/cutlass/op/op.py @@ -36,11 +36,13 @@ from bisect import bisect_left -from cutlass import option_registry +import cutlass +from cutlass import option_registry, epilogue from cutlass.backend.utils.device import device_cc from cutlass.epilogue import get_activations from cutlass.library_defaults import _generator_ccs from cutlass.swizzle import get_swizzling_functors +from cutlass.utils import datatypes class OperationBase: @@ -48,22 +50,26 @@ class OperationBase: Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d) """ - def __init__(self, cc: int = None, kernel_cc: int = None): + def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = cutlass.OperationKind.Gemm): """ :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90 :type cc: int :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80 :type kernel_cc: int """ + self.operation_kind = operation_kind self.cc = cc if cc is not None else device_cc() self.specified_kernel_cc = kernel_cc is not None self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc) self.tile_description = None - self.options = option_registry.options_for_cc(self.current_cc) + self.options = option_registry.options_for_cc(self.current_cc, operation_kind) if self.options is None: raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}") + + # Default activation function: identity + self._activation = epilogue.identity def _find_closest_cc(self, cc: int) -> int: """ @@ -113,4 +119,210 @@ def _reset_options(self, cc: int): if cc not in _generator_ccs: raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.') self.current_cc = cc - self.options = option_registry.options_for_cc(self.current_cc) + self.options = option_registry.options_for_cc(self.current_cc, self.operation_kind) + + def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name): + """ + Verifies the following properties: + 1) Either ``scalar`` or ``ref_scakar`` must be set (i.e., not ``None``) + 2) If ``scalar`` is not ``None``, its datatype must match matches the current version + set by the plan (i.e., those in ``ref_dtype``) + + If either of these properties does not hold, an exception is raised. If these properties hold and + ``scalar`` is not ``None``, ``scalar`` is returned. Otherwise, ``ref_scalar`` is returned. + + :param scalar: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type scalar: numpy/cupy/torch scalar + :param ref_scalar: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in + :type ref_scalar: numpy/cupy/torch scalar + :param ref_dtype: data type for the scalar that this object was initialized to + :param name: identifier of the scalar to verify. Used in raising exceptions + :type name: str + + :return: valid scalar to use + :rtype: numpy/cupy/torch scalar + """ + if scalar is None: + if ref_scalar is None: + raise Exception(f"Scalar {name} must be set.") + return ref_scalar + if hasattr(scalar, "dtype"): + dtype = datatypes.library_type(scalar.dtype) + if dtype != ref_dtype: + raise Exception( + f"Tensor {name} with type {dtype} does not match expected type {ref_dtype}." + ) + return scalar + + def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name): + """ + Verifies the following properties: + 1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``) + 2) If ``tensor`` is not ``None``, its datatype and layout must match matches the current versions + set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``) + + If either of these properties does not hold, an exception is raised. If these properties hold and + ``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned. + + :param tensor: object representing a tensor passed in to verify, or ``None`` if no tensor was passed in + :type tensor: numpy/cupy/torch array/tensor object + :param ref_tensor: object representing a tensor passed in on construction of this object, or ``None`` if no tensor was passed in + :type ref_tensor: numpy/cupy/torch array/tensor object + :param ref_dtype: data type for the tensor that this object was initialized to + :param ref_layout: layout for the tensor that this object was initialized to + :param name: identifier of the tensor to verify. Used in raising exceptions + :type name: str + + :return: valid tensor object to use + :rtype: numpy/cupy/torch array/tensor object + """ + if tensor is None: + if ref_tensor is None: + raise Exception(f"Tensor {name} must be set.") + return ref_tensor + + self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name) + return tensor + + # + # Opcode Related + # + + @property + def opclass(self) -> cutlass.OpcodeClass: + """ + Returns the opcode class currently in use by the GEMM + + :return: opcode class currently in use + :rtype: cutlass.OpcodeClass + """ + return self.op_class + + @opclass.setter + def opclass(self, oc: cutlass.OpcodeClass): + if isinstance(oc, str): + oc = datatypes.getattr_enum(cutlass.OpcodeClass, oc) + if oc in self.possible_op_classes: + self.op_class = oc + else: + raise Exception( + f'Unsupported operation class {oc} for CC {self.cc} and data type combination ' + f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and ' + f'layout combination ({self._layout_a}, {self._layout_b}).') + + # Changing the op class changes the elements per access in the epilogue. Reset this. + if self.op_class == cutlass.OpcodeClass.Simt: + elements_per_access = 1 + else: + elements_per_access = 128 // cutlass.DataTypeSize[self._element_c] + + if self.epilogue_functor is not None: + self.epilogue_functor = self._reset_epilogue_functor_alignment(elements_per_access, self.epilogue_functor) + + # Changing the op class also changes the possible operations available. Reset these. + self.possible_operations = self.options.operations( + self.op_class, self._element_a, self._element_b, + self._element_accumulator, self._layout_a, self._layout_b) + + # + # Epilogue + # + + def _create_epilogue_functor_activation(self, activation): + """ + Returns the epilogue functor with given activation function + """ + if self.epilogue_functor is None: + if self.op_class == cutlass.OpcodeClass.Simt: + elements_per_access = 1 + else: + elements_per_access = 128 // cutlass.DataTypeSize[self._element_c] + else: + elements_per_access = self.epilogue_functor.epilogue_vector_length + + if not self.specified_kernel_cc: + if self.current_cc == 90 and activation != epilogue.identity: + # CUTLASS 3.0 kernels currently only support identity activation. If one requests a non-identity activation, + # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. + cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") + self._reset_options(80) + self._reset_operations(reset_epilogue=False) + elif (self.cc == 90 and self.current_cc != 90 and activation == epilogue.identity): + # SM80 fallback kernels are currently used. Since an identity activation is requested, + # we can switch back to using SM90 kernels. + self._reset_options(90) + self._reset_operations(reset_epilogue=False) + else: + if self.current_cc == 90 and activation != epilogue.identity: + raise Exception("Epilogues with elementwise fusion are not currently supported " + "in the Python interface for 3.x kernels. To use 2.x kernels " + "with fused elementwise epilogues, do not set the `kernel_cc` " + "parameter when constructing the Gemm object.") + + return epilogue.get_activation_epilogue( + activation, + datatypes.binding_type(self._element_c), + elements_per_access, + datatypes.binding_type(self._element_accumulator), + datatypes.binding_type(self._element_accumulator), + ) + + def _reset_epilogue_functor_activation(self, activation): + """ + Set the epilogue functor based on the provided activation function + """ + self.epilogue_functor = self._create_epilogue_functor_activation(activation) + + def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor): + """ + Reset the alignment of the current epilogue functor based on alignment C + """ + if epilogue_functor is None or not hasattr(epilogue_functor, 'activation_functor'): + # Identity epilogue does not have 'activation_functor' + activation = epilogue.identity + else: + activation = type(epilogue_functor.activation_functor) + + epilogue_functor = epilogue.get_activation_epilogue( + activation, + datatypes.binding_type(self._element_c), + alignment, + datatypes.binding_type(self._element_accumulator), + datatypes.binding_type(self._element_accumulator), + ) + return epilogue_functor + + @property + def activation(self): + """ + Returns the type of the current activation function used + """ + if hasattr(self.epilogue_functor, "activation_functor"): + return type(self.epilogue_functor.activation_functor) + else: + return epilogue.identity + + @activation.setter + def activation(self, act): + """ + Sets the type of the activation function to use + Activation can come with a set of arguments + + :param act: type of activation function to use + :type act: str or tuple. e.g. "relu", ("leaky_relu", 0.01) + + """ + if isinstance(act, tuple): + if isinstance(act[0], str): + act_fn = getattr(cutlass.backend.epilogue, act[0]) + else: + act_fn = act[0] + self._reset_epilogue_functor_activation(act_fn) + self._activation_args = act[1] + self._activation = act[0] + else: + if isinstance(act, str): + act = getattr(cutlass.backend.epilogue, act) + self._reset_epilogue_functor_activation(act) + self._activation = act + diff --git a/python/cutlass/utils/__init__.py b/python/cutlass/utils/__init__.py index 27c114133d..2d4b703094 100644 --- a/python/cutlass/utils/__init__.py +++ b/python/cutlass/utils/__init__.py @@ -32,9 +32,10 @@ from cutlass.utils.check import ( alignment_or_default, + update_alignment, calculate_smem_usage, calculate_smem_usage_per_stage, valid_cluster_shape, - valid_kernel_schedule, + valid_schedule, valid_stage_count, ) diff --git a/python/cutlass/utils/check.py b/python/cutlass/utils/check.py index 3cd4dd1d5f..6983dbb005 100644 --- a/python/cutlass/utils/check.py +++ b/python/cutlass/utils/check.py @@ -39,29 +39,35 @@ import cutlass_bindings import cutlass from cutlass.backend.library import DataTypeSize, TileDescription +from cutlass.utils.datatypes import binding_type -def calculate_smem_usage_per_stage(tile_description, operation_kind): +def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: cutlass.OperationKind) -> int: """ Returns the amount of shared memory in bytes consumed in a single stage of a kernel. + :param td: tile description to compute shared memory of + :type td: TileDescription + :param operation_kind: identifier for the type of operation being performed + :type operation_kind: cutlass.OperationKind + :return: number of bytes of shared memory consumed by a single stage :rtype: int """ - m, n, k = tile_description.threadblock_shape + m, n, k = td.threadblock_shape if operation_kind == cutlass.OperationKind.Gemm: stage_barrier_bytes = 32 return ( - (DataTypeSize[tile_description.math_instruction.element_a] * m * k // 8) - + (DataTypeSize[tile_description.math_instruction.element_b] * k * n // 8) + (DataTypeSize[td.math_instruction.element_a] * m * k // 8) + + (DataTypeSize[td.math_instruction.element_b] * k * n // 8) + stage_barrier_bytes ) else: raise Exception(f"No available shared memory calculation for operation kind {operation.operation_kind}") -def calculate_smem_usage(operation): +def calculate_smem_usage(operation) -> int: """ Returns the amount of shared memory in bytes consumed by a kernel. @@ -72,7 +78,11 @@ def calculate_smem_usage(operation): return _per_stage * operation.tile_description.stages -def valid_stage_count(cc: int, td: TileDescription) -> tuple: +def valid_stage_count( + cc: int, + td: TileDescription, + element_C: cutlass.DataType = None, + element_D: cutlass.DataType = None) -> tuple: """ Checks whether a device with `cc` supports the number of stages within `tile_description`, both based on raw limits on the number of stages and based on shared memory capacity @@ -81,15 +91,26 @@ def valid_stage_count(cc: int, td: TileDescription) -> tuple: :type cc: int :param td: tile description to check :type td: TileDescription + :param element_C: data type of operand C + :type element_C: cutlass.DataType + :param element_D: data type of operand D + :type element_D: cutlass.DataType :return: tuple with the first element indicating whether the provided tile description is valid for the provided device and the second element being an error message :rtype: tuple """ - if cc == 90 and (td.stages is None or td.stages == 0): - # Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically - # determines the stage count to use. Thus, all settings are valid in these scenarios. - return (True, "") + if cc == 90: + if (td.stages is None or td.stages == 0): + # Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically + # determines the stage count to use. Thus, all settings are valid in these scenarios. + return (True, "") + else: + cutlass.logger.warning( + "Setting an explicit stage count for SM90 kernels currently may " + "result in compilation errors if the combination of tile shape, " + "stage count, and shared memory requirement of the epilogue exceeds " + "the available shared memory per SM.") if td.stages <= 0: return (False, f"Stage counts must be positive integers. Tile description has stage count of {td.stages}.") @@ -98,14 +119,20 @@ def valid_stage_count(cc: int, td: TileDescription) -> tuple: return (False, f"Tile description has stage count of {td.stages}, " f"but only 2 stages are supported on SM{cc}.") + # The calculation below does not consider shared memory used by the epilogue and, thus, + # only catches cases in which the mainloop exceeds the device's shared memory capacity. + # This is not a concern for CUTLASS 2.x kernels, for which the shared memory of the + # mainloop and epilogue is shared. smem_per_stage = calculate_smem_usage_per_stage(td, cutlass.OperationKind.Gemm) + smem_usage_mainloop = (smem_per_stage * td.stages) smem_arch = cutlass.SharedMemPerCC[cc] << 10 - if (smem_per_stage * td.stages) > smem_arch: + if smem_usage_mainloop > smem_arch: return ( False, "Configuration uses too much shared memory. Consider reducing stage count or tile shape.\n" - f"Details: configuration uses {smem_per_stage} bytes of shared memory per stage, and " - f"{td.stages} stages for a total of {smem_per_stage * td.stages} bytes.\n" - f"The maxmium amoung of shared memory that can be used per block on CC {cc} is {smem_arch}.") + f"Details:\n" + f"Mainloop uses {smem_per_stage} bytes of shared memory per stage, and " + f"{td.stages} stages for a total of {smem_usage_mainloop} bytes.\n" + f"The maxmium amount of shared memory that can be used per block on CC {cc} is {smem_arch}.") return (True, "") @@ -153,21 +180,40 @@ def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple: return (True, "") -def valid_kernel_schedule(cc: int, kernel_schedule: cutlass.KernelScheduleType) -> tuple: +def valid_schedule( + cc: int, + kernel_schedule: cutlass.KernelScheduleType, + epilogue_schedule: cutlass.EpilogueScheduleType, + tile_scheduler: cutlass.TileSchedulerType) -> tuple: """ - Checks whether a device with ``cc`` supports ``kernel_schedule``. + Checks that the kernel and epilogue schedules passed in are a valid combination for + a device of compute capability ``cc``. :param cc: compute capability of device in question :type cc: int :param kernel_schedule: kernel schedule type - :type KernelScheduleType: cutlass.KernelScheduleType + :type kernel_schedule: cutlass.KernelScheduleType + :param epilogue_schedule: epilogue schedule type + :type epilogue_schedule: cutlass.EpilogueScheduleType + :param tile_scheduler: tile scheduler type + :type tile_scheduler: cutlass.TileSchedulerType - :return: tuple with the first element indicating whether the provided kernel schedule is + :return: tuple with the first element indicating whether the provided schedules are valid for the provided device and the second element being an error message :rtype: tuple """ - if kernel_schedule != cutlass.KernelScheduleType.ScheduleAuto and cc < 90: - return (False, "Non-default kernel schedules are only supported on SM90 and beyond") + kernel_auto = (kernel_schedule == cutlass.KernelScheduleType.ScheduleAuto) + epilogue_auto = (epilogue_schedule == cutlass.EpilogueScheduleType.ScheduleAuto) + tile_scheduler_default = (tile_scheduler == cutlass.TileSchedulerType.Default) + if cc < 90 and not (kernel_auto and epilogue_auto and tile_scheduler_default): + return (False, "Non-default schedules are only supported on SM90 and beyond") + + if (kernel_auto and not epilogue_auto) or (not kernel_auto and epilogue_auto): + return (False, "Kernel and epilogue schedules must either both be auto or neither be auto") + + if not tile_scheduler_default: + if (tile_scheduler == cutlass.TileSchedulerType.StreamK) and (kernel_schedule != cutlass.KernelScheduleType.TmaWarpSpecializedCooperative): + return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule") return (True, "") @@ -190,3 +236,26 @@ def alignment_or_default(alignment_provided: int, default_alignment: int) -> int return alignment_provided return default_alignment + + +def update_alignment(alignment_provided:int, default_alignment: int) -> int: + """ + Returns `alignment_provided` if it is set, otherwise `default_alignment` and checks + that `alignment_provided` does not exceed `default_alignment`. + + :param alignment_provided: alignment preference specified. Can be None. + :type alignment_provided: int + :param default_alignment: alignment to use if `alignment_provided` is None + :type default_alignment: int + + :return: alignment to use + :rtype: int + """ + if alignment_provided is not None: + if alignment_provided > default_alignment: + if alignment_provided % default_alignment == 0: + return default_alignment + raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.") + return alignment_provided + + return default_alignment diff --git a/python/cutlass/utils/datatypes.py b/python/cutlass/utils/datatypes.py index 98984e3b2b..2b9eba7620 100644 --- a/python/cutlass/utils/datatypes.py +++ b/python/cutlass/utils/datatypes.py @@ -232,6 +232,8 @@ def library_layout(layout): return cutlass.LayoutType.RowMajor elif layout == cutlass_bindings.ColumnMajor: return cutlass.LayoutType.ColumnMajor + elif layout == cutlass_bindings.TensorNHWC: + return cutlass.LayoutType.TensorNHWC else: raise Exception(f"No conversion available for layout {layout} to library layout.") @@ -251,6 +253,8 @@ def binding_layout(layout): return cutlass_bindings.RowMajor elif layout == cutlass.LayoutType.ColumnMajor: return cutlass_bindings.ColumnMajor + elif layout == cutlass.LayoutType.TensorNHWC: + return cutlass_bindings.TensorNHWC else: raise Exception(f"No conversion available for layout {layout} to Python-bound CUTLASS layout.") @@ -279,6 +283,16 @@ def get_datatype_and_layout(tensor): else: raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.") +def get_tensor_shape(tensor): + if (numpy_available and isinstance(tensor, np.ndarray)) or ( + cupy_available and isinstance(tensor, cp.ndarray) + ): + return tensor.shape + elif torch_available and isinstance(tensor, torch.Tensor): + size = tensor.size() + return (size[0], size[2], size[3], size[1]) + else: + raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.") def binding_opclass(opclass: cutlass.OpcodeClass): if opclass == cutlass.OpcodeClass.TensorOp: @@ -299,7 +313,9 @@ def backend_math_operation(math_op: cutlass.MathOperation): def construct_backend_td(td: cutlass.TileDescription, - kernel_schedule: cutlass.KernelScheduleType) -> TileDescription: + kernel_schedule: cutlass.KernelScheduleType, + epilogue_schedule: cutlass.EpilogueScheduleType, + tile_scheduler: cutlass.TileSchedulerType) -> TileDescription: mi = td.math_instruction backend_mi = MathInstruction( mi.instruction_shape, @@ -309,8 +325,9 @@ def construct_backend_td(td: cutlass.TileDescription, binding_opclass(mi.opcode_class), backend_math_operation(mi.math_operation) ) + cluster_shape = td.cluster_shape if hasattr(td, "cluster_shape") else [1, 1, 1] return TileDescription(td.threadblock_shape, td.stages, td.warp_count, - backend_mi, td.cluster_shape, kernel_schedule) + backend_mi, cluster_shape, kernel_schedule, epilogue_schedule, tile_scheduler) def td_from_profiler_op(op) -> TileDescription: @@ -322,8 +339,10 @@ def td_from_profiler_op(op) -> TileDescription: :returns: backend TileDescription :rtype: cutlass.backend.TileDescription """ - schedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None - return construct_backend_td(op.tile_description, schedule) + kschedule = op.kernel_schedule if hasattr(op, 'kernel_schedule') else None + eschedule = op.epilogue_schedule if hasattr(op, 'epilogue_schedule') else None + tschedule = op.tile_scheduler if hasattr(op, 'tile_scheduler') else None + return construct_backend_td(op.tile_description, kschedule, eschedule, tschedule) def td_from_profiler_td(td: cutlass.backend.TileDescription) -> TileDescription: @@ -336,4 +355,16 @@ def td_from_profiler_td(td: cutlass.backend.TileDescription) -> TileDescription: :returns: backend TileDescription :rtype: cutlass.backend.TileDescription """ - return construct_backend_td(td, kernel_schedule=None) + return construct_backend_td(td, kernel_schedule=None, epilogue_schedule=None, tile_scheduler=None) + +def to_camel_case(snake_str): + return "".join(x.capitalize() for x in snake_str.lower().split("_")) + + +def getattr_enum(obj, attr_name): + # The attr_name is under the snake_case + camel_attr = to_camel_case(attr_name) + if hasattr(obj, camel_attr): + return getattr(obj, camel_attr) + else: + raise Exception(f"Invalid option: {attr_name}") diff --git a/python/setup.py b/python/setup.py index 266228d143..569889849c 100644 --- a/python/setup.py +++ b/python/setup.py @@ -112,6 +112,7 @@ def custom_compile(obj, src, ext, cc_args, extra_postargs, pp_opts): cuda_install_path + '/lib64', ] + ext_modules = [ Pybind11Extension('cutlass_bindings', ['cutlass/cpp/cutlass_bindings.cpp'], diff --git a/test/python/conv2d/conv2d_sm80.py b/test/python/conv2d/conv2d_sm80.py new file mode 100644 index 0000000000..32d897f90c --- /dev/null +++ b/test/python/conv2d/conv2d_sm80.py @@ -0,0 +1,138 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for Conv2d operands on SM80 +""" +from conv2d_test_utils import * +import cutlass +import logging + + +cutlass.set_log_level(logging.WARNING) +cc = 80 + +@unittest.skipIf(device_cc() != cc, 'Device compute capability is invalid for SM80 tests.') +class Conv2dSm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + +conv_problems = get_conv_problems() + +# Tests for optimized & analytic +for conv_kind in ["fprop", "wgrad", "dgrad"]: + # F16, simt + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f16, + opclass="simt", threadblock_shape=[128, 128, 8], + warp_count=[4, 2, 1], stages=2, instruction_shape=[1, 1, 1]) + # F16, tensor op + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16]) + # F16, tensor op, analytic iterator + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], iterator_algorithm="analytic") + # F16, tensor op, f32 output + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f32, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16]) + # F16, tensor op, different tile description + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 64, 32], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8]) + # F32, simt + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32, + opclass="simt", threadblock_shape=[128, 128, 8], + warp_count=[4, 2, 1], stages=4, instruction_shape=[1, 1, 1]) + # Tf32, tensorop + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32, + opclass="tensor_op", threadblock_shape=[128, 128, 16], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8] + ) + # Split-K + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode="serial", + split_k_slices=2) + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode="parallel", + split_k_slices=5) + # Swizzling functor + add_test( + Conv2dSm80, cc, conv_kind, conv_problems, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 64, 32], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 8], swizzle=4) + +# Tests for few channels and fixed channels +# F16, tensor op, few channels +for c, tb, stage, inst in zip([2, 1], + [[128, 128, 64], [128, 128, 32]], + [3, 2], + [[16, 8, 16], [16, 8, 8]]): + add_test( + Conv2dSm80, cc, "fprop", conv2d_few_channel_problemsizes(c), cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f16, + opclass="tensor_op", threadblock_shape=tb, + warp_count=[2, 2, 1], stages=stage, instruction_shape=inst, iterator_algorithm="few_channels" + ) +# F16, tensor op, fixed channels +for c in [8, 4, 2]: + add_test( + Conv2dSm80, cc, "fprop", conv2d_few_channel_problemsizes(c), cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], iterator_algorithm="fixed_channels" + ) + +# Test activations +for activation in ["relu", "leaky_relu"]: + for split_k_mode, split_k_slices in zip(["parallel", "serial", "parallel"], [1, 7, 5]): + add_test( + Conv2dSm80, cc, "fprop", conv_problems, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f16, + opclass="tensor_op", threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, instruction_shape=[16, 8, 16], split_k_mode=split_k_mode, + split_k_slices=split_k_slices, activation=activation) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/python/conv2d/conv2d_test_utils.py b/test/python/conv2d/conv2d_test_utils.py new file mode 100644 index 0000000000..4fc8f0a251 --- /dev/null +++ b/test/python/conv2d/conv2d_test_utils.py @@ -0,0 +1,508 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Util Functions for Conv2d Test +""" +import torch +import cutlass +import unittest +import cutlass_bindings +from cutlass.utils.datatypes import binding_type, binding_opclass +from cutlass.backend.test.conv2d_testbed import Conv2dLauncher, getTensorRef, getTensorView +from cutlass.backend.utils.device import device_cc +from cutlass.backend.test.utils import get_name_conv2d +import numpy as np + +def conv2d_few_channel_problemsizes(channels): + problem_sizes = [ + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 8, 8, channels), + cutlass_bindings.Tensor4DCoord(16, 3, 3, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(2, 2), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 16, 16, channels), + cutlass_bindings.Tensor4DCoord(16, 3, 3, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(2, 2), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 16, 16, channels), + cutlass_bindings.Tensor4DCoord(16, 7, 7, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 224, 224, channels), + cutlass_bindings.Tensor4DCoord(32, 7, 7, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 224, 224, channels), + cutlass_bindings.Tensor4DCoord(64, 7, 7, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(2, 2), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 224, 224, channels), + cutlass_bindings.Tensor4DCoord(64, 5, 5, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 224, 224, channels), + cutlass_bindings.Tensor4DCoord(64, 5, 5, channels), + cutlass_bindings.Tensor4DCoord(1, 1, 1, 1), + cutlass_bindings.MatrixCoord(2, 2), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, + 1, 1 + ), + ] + + return problem_sizes + +torch_dtype = { + cutlass.DataType.f16: torch.float16, + cutlass.DataType.f32: torch.float32, + cutlass.DataType.f64: torch.float64 +} + +numpy_dtype = { + cutlass.DataType.f16: np.float16, + cutlass.DataType.f32: np.float32, + cutlass.DataType.f64: np.float64 +} + + +def validate_problem_size(ps, conv_kind, split_k_slices): + P = (ps.H + 2 * ps.pad_h - ps.dilation_h * (ps.R - 1) - 1) // ps.stride_h + 1 + Q = (ps.W + 2 * ps.pad_w - ps.dilation_w * (ps.S - 1) - 1) // ps.stride_w + 1 + if P != ps.P or Q != ps.Q: + return False + + # Split-K (serial or parallel) is not supported for strided dgrad + if conv_kind == "dgrad" and split_k_slices > 1 and (ps.stride_h > 1 or ps.stride_w > 1): + return False + return True + + +# Override the backend launcher +class Conv2dLauncherFrontend(Conv2dLauncher): + def __init__(self, plan: cutlass.Conv2d, seed: int = 80, backend="numpy"): + self.operation = plan + self.conv_kind = plan.conv_kind + self.seed = seed + self.backend = backend + + self.dtype_A = plan._element_a + self.dtype_B = plan._element_b + self.dtype_C = plan._element_c + self.dtype_acc = plan._element_accumulator + + self.layout_A = cutlass_bindings.TensorNHWC + self.layout_B = cutlass_bindings.TensorNHWC + self.layout_C = cutlass_bindings.TensorNHWC + self.layout_D = cutlass_bindings.TensorNHWC + + self.element_compute = cutlass_bindings.float32 + self.enable_cached_results = True + + # Get randomization_max + if self.dtype_A in [cutlass.DataType.f16, cutlass.DataType.bf16]: + if self.dtype_acc in [cutlass.DataType.f16, cutlass.DataType.bf16]: + self.randomization_max = 2 + else: + self.randomization_max = 3 + else: + self.randomization_max = 7 + + self.activation = plan.activation + + self.host_conv2d = cutlass_bindings.test.conv.host.conv2d + + + def set_seed(self): + if self.backend == "numpy": + np.random.seed(self.seed) + else: + torch.manual_seed(self.seed) + + def uniform_init(self, size, dtype): + if self.backend == "numpy": + return super().uniform_init(size, numpy_dtype[dtype]) + else: + tensor = torch.ceil( + torch.empty(size=size, dtype=torch_dtype[dtype], device="cuda").uniform_(-self.randomization_max - 0.5, self.randomization_max - 0.5) + ).to(memory_format=torch.channels_last) + return tensor + + def zeros_like(self, tensor): + if self.backend == "numpy": + return np.zeros_like(tensor) + else: + return torch.zeros_like(tensor).to(memory_format=torch.channels_last) + + def reference(self, ps, A, B, C, alpha, beta, activation): + if self.backend == "numpy": + numpy_result = self.host_reference(ps, A, B, C, alpha, beta, activation) + return numpy_result + else: + if self.conv_kind == cutlass_bindings.conv.Operator.fprop: + torch_result = alpha * torch.ops.aten.conv2d( + A, + B, + stride=(ps.stride_h, ps.stride_w), + padding=(ps.pad_h, ps.pad_w), + dilation=(ps.dilation_h, ps.dilation_w) + ) + beta * C + elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad: + torch_result = alpha * torch.nn.grad.conv2d_input( + (ps.N, ps.C, ps.H, ps.W), + B, + A, + padding=(ps.pad_h, ps.pad_w), + stride=(ps.stride_h, ps.stride_w) + ) + beta * C + elif self.conv_kind == cutlass_bindings.conv.Operator.wgrad: + torch_result = alpha * torch.nn.grad.conv2d_weight( + B, + (ps.K, ps.C, ps.R, ps.S), + A, + padding=(ps.pad_h, ps.pad_w), + stride=(ps.stride_h, ps.stride_w) + ) + beta * C + else: + raise Exception(f"Conv kind {self.conv_kind} is currently unsupported.") + + if activation == cutlass.backend.epilogue.relu: + torch_result = torch.nn.functional.relu(torch_result) + elif activation == cutlass.backend.epilogue.leaky_relu: + torch_result = torch.nn.functional.leaky_relu(torch_result, 0.5) + + return torch_result + + def host_reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta, activation): + if self.element_compute == cutlass_bindings.float16: + alpha = cutlass_bindings.float16(alpha) + beta = cutlass_bindings.float16(beta) + elif self.element_compute == cutlass_bindings.int32: + alpha = int(alpha) + beta = int(beta) + else: + alpha = alpha + beta = beta + + # If cached result is loaded + cached_result_loaded = False + + if self.enable_cached_results: + # Get problem key + cached_test_key = cutlass_bindings.test.conv.host.CreateCachedConv2dTestKey( + self.conv_kind, + problem_size, + alpha, + beta, + getTensorView( + tensor_A, self.layout_A, self.conv_kind, problem_size, "a" + ), + getTensorView( + tensor_B, self.layout_B, self.conv_kind, problem_size, "b" + ), + getTensorView( + tensor_C, self.layout_C, self.conv_kind, problem_size, "c" + ), + ) + + cached_test_key.problem = cached_test_key.problem + f"_{activation.tag.split('::')[-1]}" + + cached_test_result = cutlass_bindings.test.conv.host.CachedTestResult() + + conv2d_result_cache_name = "cached_results_SM%d_%d.txt" % ( + self.operation.arch, + self.seed, + ) + + cached_results = cutlass_bindings.test.conv.host.CachedTestResultListing( + conv2d_result_cache_name + ) + # CachedTestResultListing cached_results(conv2d_result_cache_name); + cached = cached_results.find(cached_test_key) + cached_result_loaded = cached[0] + if cached_result_loaded: + cached_test_result = cached[1] + + if not cached_result_loaded: + # Compute the conv2d on host + tensor_D_ref = np.ones_like(tensor_C) + tensor_ref_A = getTensorRef( + tensor_A, self.layout_A, self.conv_kind, problem_size, "a" + ) + tensor_ref_B = getTensorRef( + tensor_B, self.layout_B, self.conv_kind, problem_size, "b" + ) + tensor_ref_C = getTensorRef( + tensor_C, self.layout_C, self.conv_kind, problem_size, "c" + ) + tensor_ref_D_ref = getTensorRef( + tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d" + ) + + self.host_conv2d( + self.conv_kind, + problem_size, + tensor_ref_A, + tensor_ref_B, + tensor_ref_C, + tensor_ref_D_ref, + alpha, + beta, + ) + + if activation == cutlass.backend.epilogue.leaky_relu: + tensor_D_ref = activation.numpy(tensor_D_ref, 0.5) + else: + tensor_D_ref = activation.numpy(tensor_D_ref) + + tensor_view_D_ref = getTensorView( + tensor_D_ref, self.layout_D, self.conv_kind, problem_size, "d" + ) + + if self.enable_cached_results: + cached_test_result.D = cutlass_bindings.test.conv.host.TensorHash( + tensor_view_D_ref + ) + cached_results = ( + cutlass_bindings.test.conv.host.CachedTestResultListing( + conv2d_result_cache_name + ) + ) + cached_results.append(cached_test_key, cached_test_result) + cached_results.write(conv2d_result_cache_name) + else: + return tensor_D_ref + + return cached_test_result.D + + def equal(self, tensor_D, tensor_D_ref, problem_size): + if self.backend == "numpy": + return super().equal(tensor_D, tensor_D_ref, problem_size) + else: + torch.cuda.synchronize() + return torch.equal(tensor_D, tensor_D_ref) + + + def run(self, ps, split_k_mode=cutlass_bindings.conv.SplitKMode.Serial, split_k_slices=1, alpha=1.0, beta=0.0): + + # + # Initialize input and output tensors + # + if self.conv_kind == cutlass_bindings.conv.Operator.fprop: + if self.backend == "torch": + tensor_A_size = (ps.N, ps.C, ps.H, ps.W) + tensor_B_size = (ps.K, ps.C, ps.R, ps.S) + tensor_C_size = (ps.N, ps.K, ps.P, ps.Q) + else: + tensor_A_size = (ps.N, ps.H, ps.W, ps.C) + tensor_B_size = (ps.K, ps.R, ps.S, ps.C) + tensor_C_size = (ps.N, ps.P, ps.Q, ps.K) + elif self.conv_kind == cutlass_bindings.conv.Operator.dgrad: + if self.backend == "torch": + tensor_A_size = (ps.N, ps.K, ps.P, ps.Q) + tensor_B_size = (ps.K, ps.C, ps.R, ps.S) + tensor_C_size = (ps.N, ps.C, ps.H, ps.W) + else: + tensor_A_size = (ps.N, ps.P, ps.Q, ps.K) + tensor_B_size = (ps.K, ps.R, ps.S, ps.C) + tensor_C_size = (ps.N, ps.H, ps.W, ps.C) + elif self.conv_kind == cutlass_bindings.conv.Operator.wgrad: + if self.backend == "torch": + tensor_A_size = (ps.N, ps.K, ps.P, ps.Q) + tensor_B_size = (ps.N, ps.C, ps.H, ps.W) + tensor_C_size = (ps.K, ps.C, ps.R, ps.S) + else: + tensor_A_size = (ps.N, ps.P, ps.Q, ps.K) + tensor_B_size = (ps.N, ps.H, ps.W, ps.C) + tensor_C_size = (ps.K, ps.R, ps.S, ps.C) + else: + raise Exception(f"Conv kind {self.conv_kind} is not supported") + + self.set_seed() + + tensor_A = self.uniform_init(size=tensor_A_size, dtype=self.dtype_A) + tensor_B = self.uniform_init(size=tensor_B_size, dtype=self.dtype_B) + tensor_C = self.uniform_init(size=tensor_C_size, dtype=self.dtype_C) + tensor_D = self.zeros_like(tensor_C) + + self.operation.run(tensor_A, tensor_B, tensor_C, tensor_D, + stride=(ps.stride_h, ps.stride_w), + padding=(ps.pad_h, ps.pad_w), + dilation=(ps.dilation_h, ps.dilation_w), + alpha=alpha, beta=beta, + split_k=(split_k_mode, split_k_slices)) + + tensor_D_ref = self.reference( + ps, tensor_A, tensor_B, tensor_C, alpha, beta, self.activation + ) + + return self.equal(tensor_D, tensor_D_ref, ps) + + +def add_test( + cls, + cc, + conv_kind, + problem_sizes, + element, + element_accumulator, + element_output, + opclass, + threadblock_shape, + warp_count, + instruction_shape, + stages, + iterator_algorithm=None, + swizzle=None, + split_k_mode="serial", + split_k_slices=1, + activation = "identity" +): + """Create a test-running function with the given specification""" + test_name = get_name_conv2d( + cc, conv_kind, element, element_accumulator, + element_output, opclass, threadblock_shape, warp_count, instruction_shape, stages, + iterator_algorithm, swizzle, split_k_mode, split_k_slices, activation) + + def run(self): + # Create the plan + plan = cutlass.Conv2d( + kind=conv_kind, + element=element, + element_accumulator=element_accumulator, + element_C=element_output, + element_D=element_output + ) + + # Set the opclass + plan.opclass = opclass + # Set the tile description + td = { + "threadblock_shape": threadblock_shape, + "warp_count": warp_count, + "stages": stages, + "instruction_shape": instruction_shape, + } + + plan.tile_description = td + # Set iterator algorithm + if iterator_algorithm is not None: + plan.iterator_algorithm = iterator_algorithm + # Set swizzling functor + if swizzle is not None: + plan.swizzling_stride = swizzle + + if activation != "identity": + if activation == "leaky_relu": + plan.activation = (cutlass.epilogue.leaky_relu, 0.5) + else: + plan.activation = getattr(cutlass.epilogue, activation) + + conv2d_launcher = Conv2dLauncherFrontend(plan, 80, backend="numpy") + + for ps in problem_sizes: + if not validate_problem_size(ps, conv_kind, split_k_slices): continue + + self.assertTrue( + conv2d_launcher.run(ps, split_k_mode, split_k_slices, 1.0, 0.5) + ) + + setattr(cls, test_name, run) + + return run + + +def get_conv_problems(): + # 64: minimum channel size + conv_problems = list(cutlass_bindings.test.conv.TestbedConv2dProblemSizes(64).conv2d_default_sizes) + # Insert alignment 4 & 2 tests + conv_problems += [ + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 4, 4, 12), + cutlass_bindings.Tensor4DCoord(8, 3, 3, 12), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 4, 4, 14), + cutlass_bindings.Tensor4DCoord(8, 3, 3, 14), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, + 1, 1 + ), + cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 23, 56, 98), + cutlass_bindings.Tensor4DCoord(128, 3, 3, 98), + cutlass_bindings.Tensor4DCoord(4, 0, 5, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, + 1, 1 + ), + ] + + return conv_problems diff --git a/test/python/conv2d/run_all_tests.py b/test/python/conv2d/run_all_tests.py new file mode 100644 index 0000000000..63bbefb30a --- /dev/null +++ b/test/python/conv2d/run_all_tests.py @@ -0,0 +1,42 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import unittest + + +if __name__ == '__main__': + loader = unittest.TestLoader() + tests = loader.discover('./', 'conv2d_*.py') + testRunner = unittest.runner.TextTestRunner() + results = testRunner.run(tests) + if not results.wasSuccessful(): + raise Exception('Test cases failed') diff --git a/test/python/emit/pytorch.py b/test/python/emit/pytorch.py index 3ac1c9b05c..c1d8a591ff 100644 --- a/test/python/emit/pytorch.py +++ b/test/python/emit/pytorch.py @@ -39,6 +39,7 @@ import unittest import cutlass +import cutlass_bindings if cutlass.utils.datatypes.torch_available: import torch @@ -85,6 +86,34 @@ def _generate_problems(dtype, num): Ds.append(D) return As, Bs, Cs, Ds +def _generate_conv2d_problem(conv_kind, dtype, ps): + """ + Utility function to generate conv2d inputs + + :param conv_kind: kind of convolution + :type conv_kind: str + :param dtype: data type of tensors + :param problem_size: the conv2d problem size + :type problem_size: cutlass_bindings.conv.Conv2dProblemSize + + :return: initialized tensors A, B, C, and D + :rtype: list + """ + if conv_kind == "fprop": + tensor_A_size = (ps.N, ps.C, ps.H, ps.W) + tensor_B_size = (ps.K, ps.C, ps.R, ps.S) + tensor_C_size = (ps.N, ps.K, ps.P, ps.Q) + elif conv_kind == "dgrad": + tensor_A_size = (ps.N, ps.K, ps.P, ps.Q) + tensor_B_size = (ps.K, ps.C, ps.R, ps.S) + tensor_C_size = (ps.N, ps.C, ps.H, ps.W) + else: + tensor_A_size = (ps.N, ps.K, ps.P, ps.Q) + tensor_B_size = (ps.N, ps.C, ps.H, ps.W) + tensor_C_size = (ps.K, ps.C, ps.R, ps.S) + sizes = [tensor_A_size, tensor_B_size, tensor_C_size] + return [torch.ceil(torch.empty(size, dtype=dtype, device='cuda').uniform_(-4.5, 3.5)).to(memory_format=torch.channels_last) for size in sizes] + @unittest.skipIf(not cutlass.utils.datatypes.torch_available, 'PyTorch must be available to run PyTorch extension tests') class PyTorchExtensionTest(unittest.TestCase): @@ -155,6 +184,127 @@ def check_all(X, Y): Ds_ref = [(a @ b) * alpha + (beta * c) for a, b, c in zip(As, Bs, Cs)] Ds = mod.run(As, Bs, Cs, alpha, beta) check_all(Ds, Ds_ref) + + def test_conv2d_fprop(self): + torch.manual_seed(2023) + + dtype = torch.float16 + plan = cutlass.op.Conv2d(kind="fprop", element=dtype, element_accumulator=torch.float32) + plan.activation = "relu" + + op = plan.construct() + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass.emit.pytorch(op, name="conv2d_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) + + problem_size = cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 4, 4, 16), + cutlass_bindings.Tensor4DCoord(8, 3, 3, 16), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, + 1, 1 + ) + + A, B, C = _generate_conv2d_problem("fprop", dtype, problem_size) + stride = (problem_size.stride_h, problem_size.stride_w) + padding = (problem_size.pad_h, problem_size.pad_w) + + alpha = 1.0 + beta = 0.5 + + D_ref = alpha * torch.ops.aten.conv2d( + A, B, stride=stride, padding=padding + ) + beta * C + D_ref = torch.nn.functional.relu(D_ref) + D = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta) + + assert torch.allclose(D, D_ref) + + # Test serial split-K + D_serial_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3) + assert torch.allclose(D, D_serial_split_k) + + # Test parallel split-K + D_parallel_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7) + assert torch.allclose(D, D_parallel_split_k) + + + def test_conv2d_dgrad(self): + torch.manual_seed(2023) + dtype = torch.float16 + plan = cutlass.op.Conv2d(kind="dgrad", element=dtype, element_accumulator=torch.float32) + + op = plan.construct() + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass.emit.pytorch(op, name="conv2d_dgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) + + problem_size = cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 4, 4, 16), + cutlass_bindings.Tensor4DCoord(8, 3, 3, 16), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, + 1, 1 + ) + + A, B, C = _generate_conv2d_problem("dgrad", dtype, problem_size) + stride = (problem_size.stride_h, problem_size.stride_w) + padding = (problem_size.pad_h, problem_size.pad_w) + + alpha = 1.0 + beta = 0.5 + input_size = (problem_size.N, problem_size.C, problem_size.H, problem_size.W) + D_ref = alpha * torch.nn.grad.conv2d_input( + input_size, B, A, + stride=stride, padding=padding + ) + beta * C + D = mod.run(input_size, A, B, C, stride, padding, alpha=alpha, beta=beta, ) + + assert torch.allclose(D, D_ref) + + def test_conv2d_wgrad(self): + torch.manual_seed(2023) + dtype = torch.float16 + plan = cutlass.op.Conv2d(kind="wgrad", element=dtype, element_accumulator=torch.float32) + + op = plan.construct() + with tempfile.TemporaryDirectory() as tmpdir: + mod = cutlass.emit.pytorch(op, name="conv2d_wgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) + + problem_size = cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(1, 4, 4, 16), + cutlass_bindings.Tensor4DCoord(8, 3, 3, 16), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), + cutlass_bindings.MatrixCoord(3, 3), + cutlass_bindings.MatrixCoord(1, 1), + cutlass_bindings.conv.Mode.cross_correlation, + 1, 1 + ) + + A, B, C = _generate_conv2d_problem("wgrad", dtype, problem_size) + stride = (problem_size.stride_h, problem_size.stride_w) + padding = (problem_size.pad_h, problem_size.pad_w) + + alpha = 1.0 + beta = 0.5 + weight_size = (problem_size.K, problem_size.C, problem_size.R, problem_size.S) + D_ref = alpha * torch.nn.grad.conv2d_weight( + B, weight_size, A, + stride=stride, padding=padding + ) + beta * C + D = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta) + + assert torch.allclose(D, D_ref) + + # Test serial split-K + D_serial_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3) + assert torch.allclose(D, D_serial_split_k) + + # Test parallel split-K + D_parallel_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7) + assert torch.allclose(D, D_parallel_split_k) if __name__ == '__main__': diff --git a/test/python/gemm/gemm_f16_sm80.py b/test/python/gemm/gemm_f16_sm80.py index 0c32fa5295..39174a0e5d 100644 --- a/test/python/gemm/gemm_f16_sm80.py +++ b/test/python/gemm/gemm_f16_sm80.py @@ -37,81 +37,15 @@ from functools import partial import cutlass -from cutlass.utils.datatypes import binding_opclass, binding_type -from cutlass.backend.test.gemm_testbed import test_all_gemm +import logging import unittest -from cutlass.backend.test.utils import LayoutCombination, get_name +from cutlass.backend.test.utils import LayoutCombination, add_test_gemm from cutlass.backend.utils.device import device_cc -cc = 80 - -# Partial specialziation for naming tests -bound_type = binding_type(cutlass.DataType.f16) -name_fn = partial(get_name, element_a=bound_type, element_b=bound_type, arch=cc) - - -def add_test(cls, layouts, alignments, element_output, element_accumulator, - threadblock_shape, warp_count, stages, opclass, swizzle=None): - """ - Create a test-running function with the given specification and set it as a method of `cls`. - - :param cls: class to which the generated method will be added - :type cls: type - :param layouts: layouts of A, B, and C operands - :type layouts: list or tuple - :param alignments: alingments of A, B, and C operands - :type alignments: list or tuple - :param element_output: data type of the output element - :type element_output: cutlass.DataType - :param element_accumulator: data type used in accumulation - :type element_accumulator: cutlass.DataType - :param threadblock_shape: dimensions of threadblock tiles - :type threadblock_shape: list or tuple - :param warp_count: warps to be launched per threadblock dimension - :type warp_count: list or tuple - :param stages: number of pipeline stages to use in the kernel - :type stages: int - :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) - :type opclass: cutlass.OpClass - :param swizzle: threadblock swizzling functor - """ - cluster_shape = [1, 1, 1] - - def run(self): - """ - Dynamically-generated function that constructs a GEMM operation and verifies it against - multiple test cases. - """ - element_A = cutlass.DataType.f16 - element_B = cutlass.DataType.f16 - layout_A, layout_B, layout_C = layouts - alignment_A, alignment_B, alignment_C = alignments - - plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, - element_C=element_output, element_D=element_output, - layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, - element_accumulator=element_accumulator, - kernel_cc=cc) - - plan.opclass = opclass - if swizzle is not None: - plan.swizzling_functor = swizzle - td = plan.tile_descriptions()[0] - td.threadblock_shape = threadblock_shape - td.stages = stages - td.warp_count = warp_count - td.cluster_shape = cluster_shape - op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) - self.assertTrue(test_all_gemm(op, 'universal')) - - element_epilogue = element_accumulator - name = name_fn(layouts, alignments, binding_type(element_output), binding_type(element_accumulator), - binding_type(element_epilogue), cluster_shape, threadblock_shape, stages, opclass=binding_opclass(opclass)) - setattr(cls, name, run) - - return run +cutlass.set_log_level(logging.WARNING) +cc = 80 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') class GemmF16Sm80(unittest.TestCase): @@ -128,40 +62,64 @@ class GemmF16Sm80StreamK(unittest.TestCase): """ pass +add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f16, cc=cc, cluster_shape=[1, 1, 1]) # Tests using TensorOp -add_test_tensorop = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp) - -add_test_tensorop(GemmF16Sm80, LayoutCombination.NNN, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) -add_test_tensorop(GemmF16Sm80, LayoutCombination.NNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) -add_test_tensorop(GemmF16Sm80, LayoutCombination.NTN, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) -add_test_tensorop(GemmF16Sm80, LayoutCombination.NTT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) -add_test_tensorop(GemmF16Sm80, LayoutCombination.TNN, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) -add_test_tensorop(GemmF16Sm80, LayoutCombination.TNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) -add_test_tensorop(GemmF16Sm80, LayoutCombination.TTN, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) -add_test_tensorop(GemmF16Sm80, LayoutCombination.TTT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) -add_test_tensorop(GemmF16Sm80, LayoutCombination.TNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [64, 128, 32], [1, 2, 1], 3) -add_test_tensorop(GemmF16Sm80, LayoutCombination.TNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 64, 32], [2, 1, 1], 3) -add_test_tensorop(GemmF16Sm80, LayoutCombination.TNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [64, 64, 64], [1, 1, 1], 3) -add_test_tensorop(GemmF16Sm80, LayoutCombination.TNT, [4, 4, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) -add_test_tensorop(GemmF16Sm80, LayoutCombination.TNT, [4, 4, 8], cutlass.DataType.f16, cutlass.DataType.f16, [128, 128, 32], [2, 2, 1], 3) -add_test_tensorop(GemmF16Sm80, LayoutCombination.TNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f16, [128, 128, 32], [2, 2, 1], 3) -add_test_tensorop(GemmF16Sm80, LayoutCombination.TNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [64, 64, 64], [1, 1, 1], 5) -add_test_tensorop(GemmF16Sm80, LayoutCombination.TNT, [2, 2, 2], cutlass.DataType.f16, cutlass.DataType.f16, [128, 128, 32], [2, 2, 1], 3) +add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) + +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 32], warp_count=[2, 1, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) # Tests using SIMT -add_test_simt = partial(add_test, opclass=cutlass.OpcodeClass.Simt) - -add_test_simt(GemmF16Sm80, LayoutCombination.NNN, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 8], [2, 2, 1], 2) -add_test_simt(GemmF16Sm80, LayoutCombination.TNN, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f32, [64, 128, 8], [1, 2, 1], 2) -add_test_simt(GemmF16Sm80, LayoutCombination.NTN, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f32, [128, 64, 8], [2, 1, 1], 2) -add_test_simt(GemmF16Sm80, LayoutCombination.TTN, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f32, [64, 64, 8], [1, 1, 1], 2) -add_test_simt(GemmF16Sm80, LayoutCombination.NNT, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f16, [128, 128, 8], [2, 2, 1], 2) +add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt) + +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) # Stream K tests -add_test_streamk = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) -add_test_streamk(GemmF16Sm80StreamK, LayoutCombination.NNN, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) -add_test_streamk(GemmF16Sm80StreamK, LayoutCombination.TNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [64, 64, 64], [1, 1, 1], 5) +add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5) if __name__ == '__main__': unittest.main() diff --git a/test/python/gemm/gemm_f16_sm90.py b/test/python/gemm/gemm_f16_sm90.py index 8b5ce3f5ed..90236d0554 100644 --- a/test/python/gemm/gemm_f16_sm90.py +++ b/test/python/gemm/gemm_f16_sm90.py @@ -37,86 +37,15 @@ from functools import partial import cutlass -from cutlass.utils.datatypes import binding_opclass, binding_type -from cutlass.backend.test.gemm_testbed import test_all_gemm +import logging import unittest -from cutlass.backend.test.utils import LayoutCombination, get_name +from cutlass.backend.test.utils import LayoutCombination, add_test_gemm from cutlass.backend.utils.device import device_cc -cc = 90 - -# Partial specialziation for naming tests -bound_type = binding_type(cutlass.DataType.f16) -name_fn = partial(get_name, element_a=bound_type, element_b=bound_type, arch=cc) - - -def add_test(cls, layouts, alignments, element_output, element_accumulator, - cluster_shape, threadblock_shape, stages, opclass, - kernel_schedule=cutlass.KernelScheduleType.ScheduleAuto, - swizzle=None): - """ - Create a test-running function with the given specification and set it as a method of `cls`. - - :param cls: class to which the generated method will be added - :type cls: type - :param layouts: layouts of A, B, and C operands - :type layouts: list or tuple - :param alignments: alingments of A, B, and C operands - :type alignments: list or tuple - :param element_output: data type of the output element - :type element_output: cutlass.DataType - :param element_accumulator: data type used in accumulation - :type element_accumulator: cutlass.DataType - :param cluster_shape: dimensions of threadblock cluster - :type cluster_shape: list or tuple - :param threadblock_shape: dimensions of threadblock tiles - :type threadblock_shape: list or tuple - :param warp_count: warps to be launched per threadblock dimension - :type warp_count: list or tuple - :param stages: number of pipeline stages to use in the kernel - :type stages: int - :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) - :type opclass: cutlass.OpClass - :param kernel_schedule: kernel schedule type - :type kernel_schedule: cutlass.KernelScheduleType - :param swizzle: threadblock swizzling functor - """ - - def run(self): - """ - Dynamically-generated function that constructs a GEMM operation and verifies it against - multiple test cases. - """ - element_A = cutlass.DataType.f16 - element_B = cutlass.DataType.f16 - layout_A, layout_B, layout_C = layouts - alignment_A, alignment_B, alignment_C = alignments - - plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, - element_C=element_output, element_D=element_output, - layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, - element_accumulator=element_accumulator) - - plan.opclass = opclass - if swizzle is not None: - plan.swizzling_functor = swizzle - td = plan.tile_descriptions()[0] - td.threadblock_shape = threadblock_shape - td.stages = stages - td.cluster_shape = cluster_shape - td.kernel_schedule = kernel_schedule - op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) - self.assertTrue(test_all_gemm(op, 'universal')) - - element_epilogue = element_accumulator - name = name_fn(layouts, alignments, binding_type(element_output), binding_type(element_accumulator), - binding_type(element_epilogue), cluster_shape, threadblock_shape, stages, - opclass=binding_opclass(opclass), kernel_schedule=kernel_schedule) - setattr(cls, name, run) - - return run +cutlass.set_log_level(logging.WARNING) +cc = 90 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') class GemmF16Sm90(unittest.TestCase): @@ -126,47 +55,85 @@ class GemmF16Sm90(unittest.TestCase): pass -add_test_tensorop = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp) +add_test_specialized = partial(add_test_gemm, cls=GemmF16Sm90, element=cutlass.DataType.f16, + warp_count=None, compilation_modes=['nvcc']) + +add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) # Tests with 1x1x1 clusters -add_test_tensorop(GemmF16Sm90, LayoutCombination.NNN, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [128, 128, 32], 3) -add_test_tensorop(GemmF16Sm90, LayoutCombination.NNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [128, 128, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.NTN, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [128, 128, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.NTT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [128, 128, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TNN, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [128, 128, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [4, 4, 8], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [128, 128, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [4, 4, 8], cutlass.DataType.f16, cutlass.DataType.f16, [1, 1, 1], [128, 128, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f16, [1, 1, 1], [128, 128, 32], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [64, 64, 64], 5) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TNT, [2, 2, 2], cutlass.DataType.f16, cutlass.DataType.f16, [1, 1, 1], [128, 128, 32], None) +add_test_unit_cluster = partial(add_test_tensorop, cluster_shape=[1, 1, 1]) +add_test_unit_cluster(layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], stages=3) +add_test_unit_cluster(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], stages=5) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) # Tests with different cluster shapes -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 8], cutlass.DataType.f16, cutlass.DataType.f16, [2, 2, 1], [64, 128, 64], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TNN, [8, 8, 4], cutlass.DataType.f32, cutlass.DataType.f32, [2, 2, 1], [64, 128, 64], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.NTN, [8, 8, 4], cutlass.DataType.f32, cutlass.DataType.f32, [2, 2, 1], [64, 128, 64], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.NNN, [8, 8, 4], cutlass.DataType.f32, cutlass.DataType.f32, [2, 2, 1], [64, 128, 64], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 4], cutlass.DataType.f32, cutlass.DataType.f32, [1, 4, 1], [64, 128, 64], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 4], cutlass.DataType.f32, cutlass.DataType.f32, [2, 4, 1], [64, 128, 64], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 4], cutlass.DataType.f32, cutlass.DataType.f32, [4, 1, 1], [64, 128, 64], None) -add_test_tensorop(GemmF16Sm90, LayoutCombination.TTN, [8, 8, 4], cutlass.DataType.f32, cutlass.DataType.f32, [4, 2, 1], [64, 128, 64], None) +add_test_cluster_shape = partial(add_test_tensorop, threadblock_shape=[64, 128, 64], stages=None) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f16, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TNN, alignments=[8, 8, 4], element_output=cutlass.DataType.f32, + element_accumulator=cutlass.DataType.f32, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.NTN, alignments=[8, 8, 4], element_output=cutlass.DataType.f32, + element_accumulator=cutlass.DataType.f32, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.NNN, alignments=[8, 8, 4], element_output=cutlass.DataType.f32, + element_accumulator=cutlass.DataType.f32, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass.DataType.f32, + element_accumulator=cutlass.DataType.f32, cluster_shape=[1, 4, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass.DataType.f32, + element_accumulator=cutlass.DataType.f32, cluster_shape=[2, 4, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass.DataType.f32, + element_accumulator=cutlass.DataType.f32, cluster_shape=[4, 1, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass.DataType.f32, + element_accumulator=cutlass.DataType.f32, cluster_shape=[4, 2, 1]) # Tests for different schedule modes -add_test_schedule = partial(add_test, GemmF16Sm90, LayoutCombination.TTN, [8, 8, 4], cutlass.DataType.f32, cutlass.DataType.f32, opclass=cutlass.OpcodeClass.TensorOp) -add_test_schedule([1, 1, 1], [128, 128, 64], None, kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedPingpong) -add_test_schedule([1, 1, 1], [128, 128, 64], None, kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedCooperative) -add_test_schedule([2, 1, 1], [128, 128, 64], None, kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedPingpong) -add_test_schedule([2, 1, 1], [128, 128, 64], None, kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedCooperative) -add_test_schedule([2, 1, 1], [256, 128, 64], None, kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedCooperative) -add_test_schedule([2, 1, 1], [128, 128, 64], 5, kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedPingpong) -add_test_schedule([2, 1, 1], [128, 128, 64], 5, kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedCooperative) +add_test_schedule = partial(add_test_specialized, layouts=LayoutCombination.TTN, alignments=[8, 8, 4], + element_output=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32, + opclass=cutlass.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], stages=None) +add_test_schedule( + cluster_shape=[1, 1, 1], + kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass.EpilogueScheduleType.TmaWarpSpecialized +) +add_test_schedule( + cluster_shape=[1, 1, 1], + kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedCooperative, + epilogue_schedule=cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative +) +add_test_schedule( + cluster_shape=[2, 1, 1], + kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass.EpilogueScheduleType.TmaWarpSpecialized +) +add_test_schedule( + cluster_shape=[2, 1, 1], + kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedCooperative, + epilogue_schedule=cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative +) # Tests using SIMT -add_test_simt = partial(add_test, opclass=cutlass.OpcodeClass.Simt) -add_test_simt(GemmF16Sm90, LayoutCombination.NNN, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [128, 128, 8], 2) -add_test_simt(GemmF16Sm90, LayoutCombination.TNN, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [64, 128, 8], 2) -add_test_simt(GemmF16Sm90, LayoutCombination.NTN, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [128, 64, 8], 2) -add_test_simt(GemmF16Sm90, LayoutCombination.TTN, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f32, [1, 1, 1], [64, 64, 8], 2) -add_test_simt(GemmF16Sm90, LayoutCombination.NNT, [1, 1, 1], cutlass.DataType.f16, cutlass.DataType.f16, [1, 1, 1], [128, 128, 8], 2) +add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt, alignments=[1, 1, 1], cluster_shape=[1, 1, 1], stages=2) +add_test_simt(layouts=LayoutCombination.NNN, element_output=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8]) +add_test_simt(layouts=LayoutCombination.TNN, element_output=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 8]) +add_test_simt(layouts=LayoutCombination.NTN, element_output=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 8]) +add_test_simt(layouts=LayoutCombination.TTN, element_output=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8]) +add_test_simt(layouts=LayoutCombination.NNT, element_output=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 8]) if __name__ == '__main__': diff --git a/test/python/gemm/gemm_f32_sm80.py b/test/python/gemm/gemm_f32_sm80.py index beb19f5073..f03ef737bd 100644 --- a/test/python/gemm/gemm_f32_sm80.py +++ b/test/python/gemm/gemm_f32_sm80.py @@ -37,82 +37,15 @@ from functools import partial import cutlass -from cutlass.utils.datatypes import binding_opclass, binding_type -from cutlass.backend.test.gemm_testbed import test_all_gemm +import logging import unittest -from cutlass.backend.test.utils import LayoutCombination, get_name +from cutlass.backend.test.utils import LayoutCombination, add_test_gemm from cutlass.backend.utils.device import device_cc -cc = 80 - -# Partial specialziation for naming tests -bound_type = binding_type(cutlass.DataType.f32) -name_fn = partial(get_name, element_a=bound_type, element_b=bound_type, arch=cc) - - -def add_test(cls, layouts, alignments, element_output, element_accumulator, - threadblock_shape, warp_count, stages, opclass, swizzle=None): - """ - Create a test-running function with the given specification and set it as a method of `cls`. - - :param cls: class to which the generated method will be added - :type cls: type - :param layouts: layouts of A, B, and C operands - :type layouts: list or tuple - :param alignments: alingments of A, B, and C operands - :type alignments: list or tuple - :param element_output: data type of the output element - :type element_output: cutlass.DataType - :param element_accumulator: data type used in accumulation - :type element_accumulator: cutlass.DataType - :param threadblock_shape: dimensions of threadblock tiles - :type threadblock_shape: list or tuple - :param warp_count: warps to be launched per threadblock dimension - :type warp_count: list or tuple - :param stages: number of pipeline stages to use in the kernel - :type stages: int - :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) - :type opclass: cutlass.OpClass - :param swizzle: threadblock swizzling functor - """ - cluster_shape = [1, 1, 1] - - def run(self): - """ - Dynamically-generated function that constructs a GEMM operation and verifies it against - multiple test cases. - """ - element_A = cutlass.DataType.f32 - element_B = cutlass.DataType.f32 - layout_A, layout_B, layout_C = layouts - alignment_A, alignment_B, alignment_C = alignments - - plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, - element_C=element_output, element_D=element_output, - layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, - element_accumulator=element_accumulator, - kernel_cc=cc) - - plan.opclass = opclass - if swizzle is not None: - plan.swizzling_functor = swizzle - td = plan.tile_descriptions()[0] - td.threadblock_shape = threadblock_shape - td.stages = stages - td.warp_count = warp_count - td.cluster_shape = cluster_shape - op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) - - self.assertTrue(test_all_gemm(op, 'universal')) - - element_epilogue = element_accumulator - name = name_fn(layouts, alignments, binding_type(element_output), binding_type(element_accumulator), - binding_type(element_epilogue), cluster_shape, threadblock_shape, stages, opclass=binding_opclass(opclass)) - setattr(cls, name, run) - - return run +cutlass.set_log_level(logging.WARNING) +cc = 80 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') class GemmF32Sm80(unittest.TestCase): @@ -130,25 +63,37 @@ class GemmF32Sm80StreamK(unittest.TestCase): pass -# Tests using TensorOp -add_test_tensorop = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp) +add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f32, cc=cc, cluster_shape=[1, 1, 1]) -add_test_tensorop(GemmF32Sm80, LayoutCombination.NNN, [4, 4, 4], cutlass.DataType.f32, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) -add_test_tensorop(GemmF32Sm80, LayoutCombination.NNT, [4, 4, 4], cutlass.DataType.f32, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) -add_test_tensorop(GemmF32Sm80, LayoutCombination.NTN, [4, 4, 4], cutlass.DataType.f32, cutlass.DataType.f32, [64, 128, 32], [1, 2, 1], 3) -add_test_tensorop(GemmF32Sm80, LayoutCombination.NTN, [4, 4, 4], cutlass.DataType.f32, cutlass.DataType.f32, [64, 64, 32], [1, 1, 1], 4) +# Tests using TensorOp +add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) + +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4) # Tests using SIMT -add_test_simt = partial(add_test, opclass=cutlass.OpcodeClass.Simt) - -add_test_simt(GemmF32Sm80, LayoutCombination.NNN, [1, 1, 1], cutlass.DataType.f32, cutlass.DataType.f32, [128, 128, 8], [2, 2, 1], 2) -add_test_simt(GemmF32Sm80, LayoutCombination.TNN, [1, 1, 1], cutlass.DataType.f32, cutlass.DataType.f32, [64, 128, 8], [1, 2, 1], 2) -add_test_simt(GemmF32Sm80, LayoutCombination.NTN, [1, 1, 1], cutlass.DataType.f32, cutlass.DataType.f32, [128, 64, 8], [2, 1, 1], 2) -add_test_simt(GemmF32Sm80, LayoutCombination.TTN, [1, 1, 1], cutlass.DataType.f32, cutlass.DataType.f32, [64, 64, 8], [1, 1, 1], 2) -add_test_simt(GemmF32Sm80, LayoutCombination.NNT, [1, 1, 1], cutlass.DataType.f32, cutlass.DataType.f32, [128, 128, 8], [2, 2, 1], 2) +add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt) + +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) # Stream K tests -add_test_streamk = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) -add_test_streamk(GemmF32Sm80StreamK, LayoutCombination.TTN, [4, 4, 4], cutlass.DataType.f32, cutlass.DataType.f32, [128, 128, 32], [2, 2, 1], 3) +add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) if __name__ == '__main__': diff --git a/test/python/gemm/gemm_f64_sm80.py b/test/python/gemm/gemm_f64_sm80.py index 10c43ddf2e..e1fc5d78c2 100644 --- a/test/python/gemm/gemm_f64_sm80.py +++ b/test/python/gemm/gemm_f64_sm80.py @@ -37,83 +37,15 @@ from functools import partial import cutlass -from cutlass.utils.datatypes import binding_opclass, binding_type -from cutlass.backend.test.gemm_testbed import test_all_gemm +import logging import unittest -from cutlass.backend.test.utils import LayoutCombination, get_name +from cutlass.backend.test.utils import LayoutCombination, add_test_gemm from cutlass.backend.utils.device import device_cc -cc = 80 - -# Partial specialziation for naming tests -bound_type = binding_type(cutlass.DataType.f64) -name_fn = partial(get_name, element_a=bound_type, element_b=bound_type, arch=cc) - - -def add_test(cls, layouts, alignments, element_output, element_accumulator, - threadblock_shape, warp_count, stages, opclass, swizzle=None): - """ - Create a test-running function with the given specification and set it as a method of `cls`. - - :param cls: class to which the generated method will be added - :type cls: type - :param layouts: layouts of A, B, and C operands - :type layouts: list or tuple - :param alignments: alingments of A, B, and C operands - :type alignments: list or tuple - :param element_output: data type of the output element - :type element_output: cutlass.DataType - :param element_accumulator: data type used in accumulation - :type element_accumulator: cutlass.DataType - :param threadblock_shape: dimensions of threadblock tiles - :type threadblock_shape: list or tuple - :param warp_count: warps to be launched per threadblock dimension - :type warp_count: list or tuple - :param stages: number of pipeline stages to use in the kernel - :type stages: int - :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) - :type opclass: cutlass.OpClass - :param swizzle: threadblock swizzling functor - """ - - cluster_shape = [1, 1, 1] - - def run(self): - """ - Dynamically-generated function that constructs a GEMM operation and verifies it against - multiple test cases. - """ - element_A = cutlass.DataType.f64 - element_B = cutlass.DataType.f64 - layout_A, layout_B, layout_C = layouts - alignment_A, alignment_B, alignment_C = alignments - - plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, - element_C=element_output, element_D=element_output, - layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, - element_accumulator=element_accumulator, - kernel_cc=cc) - - plan.opclass = opclass - if swizzle is not None: - plan.swizzling_functor = swizzle - td = plan.tile_descriptions()[0] - td.threadblock_shape = threadblock_shape - td.stages = stages - td.warp_count = warp_count - td.cluster_shape = cluster_shape - op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) - - self.assertTrue(test_all_gemm(op, 'universal')) - - element_epilogue = element_accumulator - name = name_fn(layouts, alignments, binding_type(element_output), binding_type(element_accumulator), - binding_type(element_epilogue), cluster_shape, threadblock_shape, stages, opclass=binding_opclass(opclass)) - setattr(cls, name, run) - - return run +cutlass.set_log_level(logging.WARNING) +cc = 80 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') class GemmF64Sm80(unittest.TestCase): @@ -131,25 +63,36 @@ class GemmF64Sm80StreamK(unittest.TestCase): pass +add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.f64, cc=cc, cluster_shape=[1, 1, 1]) + # Tests using TensorOp -add_test_tensorop = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp) +add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) -add_test_tensorop(GemmF64Sm80, LayoutCombination.NNN, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [128, 128, 16], [4, 2, 1], 3) -add_test_tensorop(GemmF64Sm80, LayoutCombination.NTN, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [64, 64, 16], [2, 2, 1], 4) -add_test_tensorop(GemmF64Sm80, LayoutCombination.TTN, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [32, 32, 16], [2, 1, 1], 5) +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, + element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, + element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 64, 16], warp_count=[2, 2, 1], stages=4) +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, + element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5) # Tests using SIMT -add_test_simt = partial(add_test, opclass=cutlass.OpcodeClass.Simt) - -add_test_simt(GemmF64Sm80, LayoutCombination.NNN, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [128, 128, 8], [2, 2, 1], 2) -add_test_simt(GemmF64Sm80, LayoutCombination.TNN, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [64, 128, 8], [1, 2, 1], 2) -add_test_simt(GemmF64Sm80, LayoutCombination.NTN, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [128, 64, 8], [2, 1, 1], 2) -add_test_simt(GemmF64Sm80, LayoutCombination.TTN, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [64, 64, 8], [1, 1, 1], 2) -add_test_simt(GemmF64Sm80, LayoutCombination.NNT, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [128, 128, 8], [2, 2, 1], 2) +add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt) + +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, + element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, + element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, + element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, + element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, + element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) # Stream K tests -add_test_streamk = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) -add_test_streamk(GemmF64Sm80StreamK, LayoutCombination.NTT, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [128, 128, 16], [4, 2, 1], 3) +add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, + element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) if __name__ == '__main__': diff --git a/test/python/gemm/gemm_f64_sm90.py b/test/python/gemm/gemm_f64_sm90.py index 4a51df9943..7626bafc57 100644 --- a/test/python/gemm/gemm_f64_sm90.py +++ b/test/python/gemm/gemm_f64_sm90.py @@ -37,89 +37,15 @@ from functools import partial import cutlass -from cutlass.utils.datatypes import binding_opclass, binding_type -from cutlass.backend.test.gemm_testbed import test_all_gemm +import logging import unittest -from cutlass.backend.test.utils import LayoutCombination, get_name +from cutlass.backend.test.utils import LayoutCombination, add_test_gemm from cutlass.backend.utils.device import device_cc -cc = 90 - -# Partial specialziation for naming tests -bound_type = binding_type(cutlass.DataType.f64) -name_fn = partial(get_name, element_a=bound_type, element_b=bound_type, arch=cc) - - -def add_test(cls, layouts, alignments, element_output, element_accumulator, - cluster_shape, threadblock_shape, stages, opclass, persistent=False, swizzle=None): - """ - Create a test-running function with the given specification and set it as a method of `cls`. - - :param cls: class to which the generated method will be added - :type cls: type - :param layouts: layouts of A, B, and C operands - :type layouts: list or tuple - :param alignments: alingments of A, B, and C operands - :type alignments: list or tuple - :param element_output: data type of the output element - :type element_output: cutlass.DataType - :param element_accumulator: data type used in accumulation - :type element_accumulator: cutlass.DataType - :param cluster_shape: dimensions of threadblock cluster - :type cluster_shape: list or tuple - :param threadblock_shape: dimensions of threadblock tiles - :type threadblock_shape: list or tuple - :param warp_count: warps to be launched per threadblock dimension - :type warp_count: list or tuple - :param stages: number of pipeline stages to use in the kernel - :type stages: int - :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) - :type opclass: cutlass.OpClass - :param persistent: whether this is a persistent warp-specialized kernel - :type persistent: bool - :param swizzle: threadblock swizzling functor - """ - - def run(self): - """ - Dynamically-generated function that constructs a GEMM operation and verifies it against - multiple test cases. - """ - element_A = cutlass.DataType.f64 - element_B = cutlass.DataType.f64 - layout_A, layout_B, layout_C = layouts - alignment_A, alignment_B, alignment_C = alignments - - plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, - element_C=element_output, element_D=element_output, - layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, - element_accumulator=element_accumulator) - - plan.opclass = opclass - if swizzle is not None: - plan.swizzling_functor = swizzle - td = plan.tile_descriptions()[0] - td.threadblock_shape = threadblock_shape - td.stages = stages - td.cluster_shape = cluster_shape - td.persistent = persistent - op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) - self.assertTrue(test_all_gemm(op, 'universal')) - - if persistent: - suffix = "_persistent" - else: - suffix = "" - - element_epilogue = element_accumulator - name = name_fn(layouts, alignments, binding_type(element_output), binding_type(element_accumulator), - binding_type(element_epilogue), cluster_shape, threadblock_shape, stages, - opclass=binding_opclass(opclass), suffix=suffix) - setattr(cls, name, run) - - return run +cutlass.set_log_level(logging.WARNING) +cc = 90 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') class GemmF64Sm90(unittest.TestCase): @@ -129,13 +55,14 @@ class GemmF64Sm90(unittest.TestCase): pass -add_test_tensorop = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp) -add_test_simt = partial(add_test, opclass=cutlass.OpcodeClass.Simt) +add_test_specialized = partial(add_test_gemm, cls=GemmF64Sm90, alignments=[1, 1, 1], cluster_shape=[1, 1, 1], + element=cutlass.DataType.f64, element_output=cutlass.DataType.f64, + element_accumulator=cutlass.DataType.f64, compilation_modes=['nvcc']) -add_test_tensorop(GemmF64Sm90, LayoutCombination.NNT, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [1, 1, 1], [128, 128, 32], 3) -add_test_tensorop(GemmF64Sm90, LayoutCombination.TNN, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [1, 1, 1], [128, 128, 32], 3) -add_test_simt(GemmF64Sm90, LayoutCombination.NNN, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [1, 1, 1], [128, 128, 8], 2) -add_test_simt(GemmF64Sm90, LayoutCombination.TTT, [1, 1, 1], cutlass.DataType.f64, cutlass.DataType.f64, [1, 1, 1], [64, 128, 8], 2) +add_test_specialized(opclass=cutlass.OpcodeClass.TensorOp, layouts=LayoutCombination.NNT, threadblock_shape=[128, 128, 32], stages=3) +add_test_specialized(opclass=cutlass.OpcodeClass.TensorOp, layouts=LayoutCombination.TNN, threadblock_shape=[128, 128, 32], stages=3) +add_test_specialized( opclass=cutlass.OpcodeClass.Simt, layouts=LayoutCombination.NNN, threadblock_shape=[128, 128, 8], stages=2) +add_test_specialized( opclass=cutlass.OpcodeClass.Simt, layouts=LayoutCombination.TTT, threadblock_shape=[ 64, 128, 8], stages=2) if __name__ == '__main__': diff --git a/test/python/gemm/gemm_s8_sm80.py b/test/python/gemm/gemm_s8_sm80.py index 128f5e5825..3ca2e67449 100644 --- a/test/python/gemm/gemm_s8_sm80.py +++ b/test/python/gemm/gemm_s8_sm80.py @@ -37,83 +37,15 @@ from functools import partial import cutlass -from cutlass.utils.datatypes import binding_opclass, binding_type -from cutlass.backend.test.gemm_testbed import test_all_gemm +import logging import unittest -from cutlass.backend.test.utils import LayoutCombination, get_name +from cutlass.backend.test.utils import LayoutCombination, add_test_gemm from cutlass.backend.utils.device import device_cc -cc = 80 - -# Partial specialziation for naming tests -bound_type = binding_type(cutlass.DataType.s8) -name_fn = partial(get_name, element_a=bound_type, element_b=bound_type, arch=cc) - - -def add_test(cls, layouts, alignments, element_output, element_accumulator, - threadblock_shape, warp_count, stages, opclass, swizzle=None): - """ - Create a test-running function with the given specification and set it as a method of `cls`. - - :param cls: class to which the generated method will be added - :type cls: type - :param layouts: layouts of A, B, and C operands - :type layouts: list or tuple - :param alignments: alingments of A, B, and C operands - :type alignments: list or tuple - :param element_output: data type of the output element - :type element_output: cutlass.DataType - :param element_accumulator: data type used in accumulation - :type element_accumulator: cutlass.DataType - :param threadblock_shape: dimensions of threadblock tiles - :type threadblock_shape: list or tuple - :param warp_count: warps to be launched per threadblock dimension - :type warp_count: list or tuple - :param stages: number of pipeline stages to use in the kernel - :type stages: int - :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) - :type opclass: cutlass.OpClass - :param swizzle: threadblock swizzling functor - """ - - cluster_shape = [1, 1, 1] - - def run(self): - """ - Dynamically-generated function that constructs a GEMM operation and verifies it against - multiple test cases. - """ - element_A = cutlass.DataType.s8 - element_B = cutlass.DataType.s8 - layout_A, layout_B, layout_C = layouts - alignment_A, alignment_B, alignment_C = alignments - - plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, - element_C=element_output, element_D=element_output, - layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, - element_accumulator=element_accumulator, - kernel_cc=cc) - - plan.opclass = opclass - if swizzle is not None: - plan.swizzling_functor = swizzle - td = plan.tile_descriptions()[0] - td.threadblock_shape = threadblock_shape - td.stages = stages - td.warp_count = warp_count - td.cluster_shape = cluster_shape - op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) - - self.assertTrue(test_all_gemm(op, 'universal')) - - element_epilogue = element_accumulator - name = name_fn(layouts, alignments, binding_type(element_output), binding_type(element_accumulator), - binding_type(element_epilogue), cluster_shape, threadblock_shape, stages, opclass=binding_opclass(opclass)) - setattr(cls, name, run) - - return run +cutlass.set_log_level(logging.WARNING) +cc = 80 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') class GemmS8Sm80(unittest.TestCase): @@ -131,25 +63,36 @@ class GemmS8Sm80StreamK(unittest.TestCase): pass +add_test_specialized = partial(add_test_gemm, element=cutlass.DataType.s8, cc=cc, cluster_shape=[1, 1, 1]) + # Tests using TensorOp -add_test_tensorop = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp) +add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) -add_test_tensorop(GemmS8Sm80, LayoutCombination.TNN, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [256, 128, 64], [4, 2, 1], 3) -add_test_tensorop(GemmS8Sm80, LayoutCombination.TNT, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [128, 256, 64], [2, 4, 1], 3) -add_test_tensorop(GemmS8Sm80, LayoutCombination.TNN, [16, 16, 4], cutlass.DataType.s32, cutlass.DataType.s32, [64, 64, 64], [1, 1, 1], 4) +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, + element_accumulator=cutlass.DataType.s32, threadblock_shape=[256, 128, 64], warp_count=[4, 2, 1], stages=3) +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, + element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3) +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 4], element_output=cutlass.DataType.s32, + element_accumulator=cutlass.DataType.s32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=4) # Tests using SIMT -add_test_simt = partial(add_test, opclass=cutlass.OpcodeClass.Simt) - -add_test_simt(GemmS8Sm80, LayoutCombination.NNN, [1, 1, 1], cutlass.DataType.s8, cutlass.DataType.s32, [128, 128, 8], [2, 2, 1], 2) -add_test_simt(GemmS8Sm80, LayoutCombination.TNN, [1, 1, 1], cutlass.DataType.s8, cutlass.DataType.s32, [64, 128, 8], [1, 2, 1], 2) -add_test_simt(GemmS8Sm80, LayoutCombination.NTN, [1, 1, 1], cutlass.DataType.s8, cutlass.DataType.s32, [128, 64, 8], [2, 1, 1], 2) -add_test_simt(GemmS8Sm80, LayoutCombination.TTN, [1, 1, 1], cutlass.DataType.s32, cutlass.DataType.s32, [64, 64, 8], [1, 1, 1], 2) -add_test_simt(GemmS8Sm80, LayoutCombination.NNT, [1, 1, 1], cutlass.DataType.s32, cutlass.DataType.s32, [128, 128, 8], [2, 2, 1], 2) +add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt) + +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, + element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, + element_accumulator=cutlass.DataType.s32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, + element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s32, + element_accumulator=cutlass.DataType.s32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.s32, + element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) # Stream K tests -add_test_streamk = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) -add_test_streamk(GemmS8Sm80StreamK, LayoutCombination.TNT, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [128, 256, 64], [2, 4, 1], 3) +add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmS8Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, + element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3) if __name__ == '__main__': diff --git a/test/python/gemm/gemm_s8_sm90.py b/test/python/gemm/gemm_s8_sm90.py index 376c80b530..2ea4ddd22b 100644 --- a/test/python/gemm/gemm_s8_sm90.py +++ b/test/python/gemm/gemm_s8_sm90.py @@ -37,89 +37,15 @@ from functools import partial import cutlass -from cutlass.utils.datatypes import binding_opclass, binding_type -from cutlass.backend.test.gemm_testbed import test_all_gemm +import logging import unittest -from cutlass.backend.test.utils import LayoutCombination, get_name +from cutlass.backend.test.utils import LayoutCombination, add_test_gemm from cutlass.backend.utils.device import device_cc -cc = 90 - -# Partial specialziation for naming tests -bound_type = binding_type(cutlass.DataType.s8) -name_fn = partial(get_name, element_a=bound_type, element_b=bound_type, arch=cc) - - -def add_test(cls, layouts, alignments, element_output, element_accumulator, - cluster_shape, threadblock_shape, stages, opclass, persistent=False, swizzle=None): - """ - Create a test-running function with the given specification and set it as a method of `cls`. - - :param cls: class to which the generated method will be added - :type cls: type - :param layouts: layouts of A, B, and C operands - :type layouts: list or tuple - :param alignments: alingments of A, B, and C operands - :type alignments: list or tuple - :param element_output: data type of the output element - :type element_output: cutlass.DataType - :param element_accumulator: data type used in accumulation - :type element_accumulator: cutlass.DataType - :param cluster_shape: dimensions of threadblock cluster - :type cluster_shape: list or tuple - :param threadblock_shape: dimensions of threadblock tiles - :type threadblock_shape: list or tuple - :param warp_count: warps to be launched per threadblock dimension - :type warp_count: list or tuple - :param stages: number of pipeline stages to use in the kernel - :type stages: int - :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) - :type opclass: cutlass.OpClass - :param persistent: whether this is a persistent warp-specialized kernel - :type persistent: bool - :param swizzle: threadblock swizzling functor - """ - - def run(self): - """ - Dynamically-generated function that constructs a GEMM operation and verifies it against - multiple test cases. - """ - element_A = cutlass.DataType.s8 - element_B = cutlass.DataType.s8 - layout_A, layout_B, layout_C = layouts - alignment_A, alignment_B, alignment_C = alignments - - plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, - element_C=element_output, element_D=element_output, - layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, - element_accumulator=element_accumulator) - - plan.opclass = opclass - if swizzle is not None: - plan.swizzling_functor = swizzle - td = plan.tile_descriptions()[0] - td.threadblock_shape = threadblock_shape - td.stages = stages - td.cluster_shape = cluster_shape - td.persistent = persistent - op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) - self.assertTrue(test_all_gemm(op, 'universal')) - - if persistent: - suffix = "_persistent" - else: - suffix = "" - - element_epilogue = element_accumulator - name = name_fn(layouts, alignments, binding_type(element_output), binding_type(element_accumulator), - binding_type(element_epilogue), cluster_shape, threadblock_shape, stages, - opclass=binding_opclass(opclass), suffix=suffix) - setattr(cls, name, run) - - return run +cutlass.set_log_level(logging.WARNING) +cc = 90 @unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') class GemmS8Sm90(unittest.TestCase): @@ -129,26 +55,40 @@ class GemmS8Sm90(unittest.TestCase): pass -add_test_tensorop = partial(add_test, opclass=cutlass.OpcodeClass.TensorOp) +add_test_specialized = partial(add_test_gemm, cls=GemmS8Sm90, element=cutlass.DataType.s8, compilation_modes=['nvcc']) + +add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) # Tests with 1x1x1 clusters -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNN, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [1, 1, 1], [128, 128, 128], 3) -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [1, 1, 1], [128, 128, 128], None) -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 8], cutlass.DataType.s8, cutlass.DataType.s32, [1, 1, 1], [128, 128, 128], None) -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [1, 1, 1], [64, 128, 128], None) -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [1, 1, 1], [128, 64, 32], None) -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [4, 4, 16], cutlass.DataType.s8, cutlass.DataType.s32, [1, 1, 1], [128, 128, 128], None) +add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, + element_accumulator=cutlass.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, + element_accumulator=cutlass.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 8], element_output=cutlass.DataType.s8, + element_accumulator=cutlass.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, + element_accumulator=cutlass.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, + element_accumulator=cutlass.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 64, 32], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[ 4, 4, 16], element_output=cutlass.DataType.s8, + element_accumulator=cutlass.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) # Tests with different cluster shapes -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [2, 2, 1], [128, 128, 128], None) -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [1, 4, 1], [128, 128, 128], None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, + element_accumulator=cutlass.DataType.s32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, + element_accumulator=cutlass.DataType.s32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None) -# Tests with persistent warp-specialized threadblocks -add_test_tensorop(GemmS8Sm90, LayoutCombination.TNT, [16, 16, 16], cutlass.DataType.s8, cutlass.DataType.s32, [2, 1, 1], [128, 128, 128], None, persistent=True) +# Tests with warp-specialized ping-pong schedule +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, + element_accumulator=cutlass.DataType.s32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None, + kernel_schedule=cutlass.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass.EpilogueScheduleType.TmaWarpSpecialized) # Tests for SIMT -add_test_simt = partial(add_test, opclass=cutlass.OpcodeClass.Simt) -add_test_simt(GemmS8Sm90, LayoutCombination.TNN, [1, 1, 1], cutlass.DataType.s8, cutlass.DataType.s32, [1, 1, 1], [64, 32, 8], 2) +add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt) +add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, + element_accumulator=cutlass.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2) if __name__ == '__main__': diff --git a/test/python/interface/conv2d_interface.py b/test/python/interface/conv2d_interface.py new file mode 100644 index 0000000000..9667979621 --- /dev/null +++ b/test/python/interface/conv2d_interface.py @@ -0,0 +1,285 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Tests the high-level Conv2d interface +""" + +from math import ceil +import unittest + +import cutlass +import cutlass_bindings +import cutlass.utils.datatypes as datatypes +from cutlass.backend.utils.device import device_cc +from utils import ExpectException +import os + + +class Conv2dEquivalence: + """ + Helper class for testing the equivalence of different constructions of the Conv2d interface + """ + def __init__(self, conv_kind, element_A, element_B, element_C, element_D, element_accumulator, + alignment_A, alignment_B, alignment_C): + + self.element_A = element_A + self.element_B = element_B + self.element_C = element_C + self.element_D = element_D + self.element_accumulator = element_accumulator + self.alignment_A = alignment_A + self.alignment_B = alignment_B + self.alignment_C = alignment_C + + self.conv_kind = conv_kind + + self.plan = cutlass.op.Conv2d( + kind=self.conv_kind, element_A=element_A, element_B=element_B, element_C=element_C, + element_D=element_D, element_accumulator=element_accumulator) + + self.op = self.plan.construct( + alignment_A=self.alignment_A, alignment_B=self.alignment_B, + alignment_C=self.alignment_C) + + def _plans_equal(self, other_plan) -> bool: + """ + Compares whether two plans are equal + + :param other_plan: plan to compare against the default Conv2d + :type other_plan: cutlass.op.Conv2d + + :return: whether `other_plan` is equivalent to `self.plan` + :rtype: bool + """ + other_op = other_plan.construct( + alignment_A=self.alignment_A, alignment_B=self.alignment_B, + alignment_C=self.alignment_C) + + return self.op.rt_module.emit() == other_op.rt_module.emit() + + def generic_test(self): + """ + Tests the equivalence of various constructions of the Conv2d interface when using CUTLASS data types + and layouts for constructing the Conv2d interface + """ + if not datatypes.numpy_available: + return + + # Test when specifying all parameters + plan_other = cutlass.op.Conv2d( + kind=self.conv_kind, + element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A + plan_other = cutlass.op.Conv2d( + kind=self.conv_kind, + element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + element=self.element_A) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A and B as tensors using generic element and output + plan_other = cutlass.op.Conv2d( + kind=self.conv_kind, + element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + element=self.element_A) + assert self._plans_equal(plan_other) + + # Test without explicit accumulator. Only run if the type of C and the accumulator are equal + if self.element_C == self.element_accumulator: + plan_other = cutlass.op.Conv2d( + kind=self.conv_kind, + element_C=self.element_C, + element_D=self.element_D, + element=self.element_A) + assert self._plans_equal(plan_other) + + # Test with only the generic types. Only rune if the types of A, B, C, and D are the same + if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D + and self.element_A == self.element_accumulator): + plan_other = cutlass.op.Conv2d(kind=self.conv_kind, element=self.element_A) + assert self._plans_equal(plan_other) + + def numpy_test(self): + """ + Tests the equivalence of various constructions of the Conv2d interface when using numpy as a frontend + """ + if not datatypes.numpy_available: + return + + import numpy as np + type_A = datatypes.numpy_type(self.element_A) + type_B = datatypes.numpy_type(self.element_B) + type_C = datatypes.numpy_type(self.element_C) + type_D = datatypes.numpy_type(self.element_D) + type_accum = datatypes.numpy_type(self.element_accumulator) + + size = (2, 2) + A = np.zeros(size, dtype=type_A) + B = np.zeros(size, dtype=type_B) + C = np.zeros(size, dtype=type_C) + D = np.zeros(size, dtype=type_D) + + return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D) + + def torch_test(self): + """ + Tests the equivalence of various constructions of the Conv2d interface when using torch as a frontend + """ + if not datatypes.torch_available: + return + + import torch + type_A = datatypes.torch_type(self.element_A) + type_B = datatypes.torch_type(self.element_B) + type_C = datatypes.torch_type(self.element_C) + type_D = datatypes.torch_type(self.element_D) + type_accum = datatypes.torch_type(self.element_accumulator) + + size = (2, 2) + + A = torch.empty(size, dtype=type_A) + B = torch.empty(size, dtype=type_B) + C = torch.empty(size, dtype=type_C) + D = torch.empty(size, dtype=type_D) + + return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D) + + def tensor_test(self, type_A, type_B, type_C, type_D, type_accum, A, B, C, D): + # Test when specifying all parameters via tensors + plan_np = cutlass.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D, element_accumulator=type_accum) + assert self._plans_equal(plan_np) + + # Test when specifying all parameters but A as tensors + plan_np = cutlass.op.Conv2d(kind=self.conv_kind, B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A) + assert self._plans_equal(plan_np) + + # Test when specifying all parameters but A and B as tensors and using generic element and output + if type_A == type_B: + plan_np = cutlass.op.Conv2d(kind=self.conv_kind, C=C, D=D, element_accumulator=type_accum, element=type_A) + assert self._plans_equal(plan_np) + + # Test without explicit accumulator. Only run if the type of C and the accumulator. + if type_C == type_accum: + plan_np = cutlass.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D) + assert self._plans_equal(plan_np) + + # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. + if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum): + plan_np = cutlass.op.Conv2d(kind=self.conv_kind, element=type_A) + assert self._plans_equal(plan_np) + + def test_all(self): + """ + Runs all tests on the Gemm interface + """ + self.generic_test() + self.numpy_test() + self.torch_test() + + +@unittest.skipIf(device_cc() <= 80, 'Device compute capability is insufficient for SM80 tests.') +class ConvEquivalenceTest(unittest.TestCase): + """ + Tests the equivalence of different constructions of the Conv2d interface + """ + pass + +type2alignment = { + cutlass.DataType.f16: 8, + cutlass.DataType.f32: 4 +} + +def add_test(conv_kind, element_A, element_B, element_C, element_D, element_accumulator): + + test_name = f"test_conv2d_{conv_kind}_{element_A}_{element_B}_{element_C}_{element_D}_{element_accumulator}" + + def run(self): + conv2d_eq = Conv2dEquivalence( + conv_kind=conv_kind, + element_A=element_A, element_B=element_B, + element_C=element_C, element_D=element_D, + element_accumulator=element_accumulator, + alignment_A=type2alignment[element_A], alignment_B=type2alignment[element_B], + alignment_C=type2alignment[element_C] + ) + conv2d_eq.test_all() + + setattr(ConvEquivalenceTest, test_name, run) + +for conv_kind in ["fprop", "wgrad", "dgrad"]: + for types in [ + [cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16], + [cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32], + [cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f16], + [cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32], + [cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32] + ]: + add_test(conv_kind, types[0], types[1], types[2], types[3], types[4]) + + +@unittest.skipIf(device_cc() <= 80, 'Device compute capability is insufficient for SM80 tests.') +class Conv2dErrorTests(unittest.TestCase): + """ + Tests various error scenarios that arise with the high-level Gemm interface + """ + + def test_alignment(self): + """ + Tests case in which the alignment specified is unsupported + """ + plan = cutlass.op.Conv2d(kind="fprop", element=cutlass.DataType.f16) + + with ExpectException(True, 'Alignment 3 is not supported for F16. The construction should fail.'): + op = plan.construct(alignment_A=3, alignment_B=3, alignment_C=3) + + def test_invalid_tile_description(self): + """ + Tests scenarios in which an invalid tile description is provided for a given CC + """ + plan = cutlass.op.Conv2d(kind="fprop", element=cutlass.DataType.f16) + + td = plan.tile_descriptions()[0] + td.threadblock_shape=[17, 32, 5] + + plan.tile_description = td + with ExpectException(True, 'The threadblock shape is invalid. The compilation should fail.'): + plan.compile() + # Clean up the error message + os.remove("./cutlass_python_compilation_device_error.txt") + +if __name__ == '__main__': + unittest.main() diff --git a/test/python/interface/gemm_interface.py b/test/python/interface/gemm_interface.py index 7696a5b08d..d8b7d648be 100644 --- a/test/python/interface/gemm_interface.py +++ b/test/python/interface/gemm_interface.py @@ -41,6 +41,7 @@ import cutlass_bindings import cutlass.utils.datatypes as datatypes from cutlass.backend.utils.device import device_cc +from utils import ExpectException class GemmEquivalence: @@ -220,38 +221,6 @@ def test_gemm_equivalence_f64_f64_f64_f64_f64_tnt_1_1_1(self): gemm_eq.test_all() -class ExpectException: - """ - Utility class to assert that an exception was raised when expected - - Example: - - .. highlight:: python - .. code-block:: python - - with ExceptionExpected(True, 'Division by zero'): - x = 1.0 / 0.0 - - :param exception_expected: whether an exception is expected to be raised - :type exception_expected: bool - :param message: message to print if an exception is raised when not expected or vice versa - :type message: str - """ - def __init__(self, exception_expected: bool, message: str = ''): - self.exception_expected = exception_expected - self.message = message - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, traceback): - exception_raised = exc_type is not None - assert self.exception_expected == exception_raised, self.message - - # Suppress the exception - return True - - class GemmErrorTests(unittest.TestCase): """ Tests various error scenarios that arise with the high-level Gemm interface @@ -316,9 +285,22 @@ def test_invalid_tile_description(self): td.stages = 0 plan.construct(td) - with ExpectException(cc < 80, f'Requested more than 2 stages on SM{cc}'): - td.stages = 3 - plan.construct(td) + if cc < 90: + with ExpectException(cc < 80, f'Requested more than 2 stages on SM{cc}'): + td.stages = 3 + plan.construct(td) + else: + original_kschedule = td.kernel_schedule + original_eschedule = td.epilogue_schedule + with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'): + td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedPingpong + td.epilogue_schedule = cutlass.EpilogueScheduleType.NoSmemWarpSpecialized + td.stages = 3 + plan.construct(td) + + # Reset schedules + td.kernel_schedule = original_kschedule + td.epilogue_schedule = original_eschedule with ExpectException(True, f'Requested too many stages'): td.stages = 100 @@ -335,9 +317,25 @@ def test_invalid_tile_description(self): # Reset cluster shape td.cluster_shape = cluster_shape - kernel_schedule = td.kernel_schedule - with ExpectException(cc < 90, f'Requested a persistent kernel on SM{cc}'): + with ExpectException(cc < 90, f'Requested a non-auto schedule on SM{cc}'): + td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedPingpong + td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecialized + plan.construct(td) + + with ExpectException(True, f'Requested a non-auto kernel schedule with an auto epilogue schedule'): td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedPingpong + td.epilogue_schedule = cutlass.EpilogueScheduleType.ScheduleAuto + plan.construct(td) + + with ExpectException(True, f'Requested an auto kernel schedule with a non-auto epilogue schedule'): + td.kernel_schedule = cutlass.KernelScheduleType.ScheduleAuto + td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecialized + plan.construct(td) + + with ExpectException(cc < 90, f'Requested a tile scheduler on SM{cc}'): + td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedCooperative + td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative + td.tile_scheduler = cutlass.TileSchedulerType.StreamK plan.construct(td) # Ensure that all returned tile descriptions are unique diff --git a/test/python/interface/utils.py b/test/python/interface/utils.py new file mode 100644 index 0000000000..b7050d6c28 --- /dev/null +++ b/test/python/interface/utils.py @@ -0,0 +1,65 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Helper functions & classes for interface test +""" +class ExpectException: + """ + Utility class to assert that an exception was raised when expected + + Example: + + .. highlight:: python + .. code-block:: python + + with ExceptionExpected(True, 'Division by zero'): + x = 1.0 / 0.0 + + :param exception_expected: whether an exception is expected to be raised + :type exception_expected: bool + :param message: message to print if an exception is raised when not expected or vice versa + :type message: str + """ + def __init__(self, exception_expected: bool, message: str = ''): + self.exception_expected = exception_expected + self.message = message + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, traceback): + exception_raised = exc_type is not None + assert self.exception_expected == exception_raised, self.message + + # Suppress the exception + return True diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 48a55d333b..d102667f95 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -63,7 +63,7 @@ set(CUTLASS_TEST_UNIT_RESULTS_CACHE_DIR ${CMAKE_CURRENT_LIST_DIR}/data/hashes) function(cutlass_test_unit_add_executable NAME) - set(options) + set(options WITHOUT_CUDA) set(oneValueArgs) set(multiValueArgs) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -77,13 +77,22 @@ function(cutlass_test_unit_add_executable NAME) PRIVATE ${CUTLASS_UNIT_TEST_COMMON_DIR} ) - - target_link_libraries( - ${NAME} - PRIVATE - cutlass_test_unit_infra - cutlass_test_unit_infra_lib - ) + if (__WITHOUT_CUDA) + # Avoid CUDA dependencies for host-only unit tests that provide the + # WITHOUT_CUDA argument. + target_link_libraries( + ${NAME} + PUBLIC + gtest + ) + else() + target_link_libraries( + ${NAME} + PRIVATE + cutlass_test_unit_infra + cutlass_test_unit_infra_lib + ) + endif() if (CUTLASS_ENABLE_OPENMP_TESTS AND OpenMP_CXX_FOUND) target_link_libraries(${NAME} PRIVATE OpenMP::OpenMP_CXX) diff --git a/test/unit/conv/CMakeLists.txt b/test/unit/conv/CMakeLists.txt index 8c84322db7..991816581d 100644 --- a/test/unit/conv/CMakeLists.txt +++ b/test/unit/conv/CMakeLists.txt @@ -46,3 +46,4 @@ foreach(SUBDIR add_dependencies(test_unit_conv test_unit_conv_${SUBDIR}) endforeach() + diff --git a/test/unit/conv/device/CMakeLists.txt b/test/unit/conv/device/CMakeLists.txt index 11671dee1e..38f8ac94ec 100644 --- a/test/unit/conv/device/CMakeLists.txt +++ b/test/unit/conv/device/CMakeLists.txt @@ -184,6 +184,7 @@ if (CUTLASS_NVCC_MAX_ARCH GREATER_EQUAL 80) # Conv2d (Strided Dgrad) conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu conv2d_strided_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.cu + conv2d_strided_dgrad_implicit_gemm_swizzling4_sm80.cu # Conv3d conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu diff --git a/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_swizzling4_sm80.cu b/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_swizzling4_sm80.cu new file mode 100644 index 0000000000..6f93fec2b6 --- /dev/null +++ b/test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_swizzling4_sm80.cu @@ -0,0 +1,99 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Implicit GEMM interface with swizzling functor > 1 +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_conv2d_dgrad.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "conv2d_testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) +TEST(SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_swizzle4, + 128x64_32x3_64x32x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + + /// Device-level Conv2d instance + using Conv2dDgradKernel = typename cutlass::conv::kernel::DefaultConv2dDgrad< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 2, + ElementAccumulator, + ElementCompute + >, + cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>, + 3, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kStrided, + 8, + 2 + >::Kernel; + + using Conv2dDgrad = cutlass::conv::device::ImplicitGemmConvolution; + + + test::conv::device::Conv2dProblemVector problem_size_list; + + + // run specific problem size in the unit test first + problem_size_list.push_back(cutlass::conv::Conv2dProblemSize( + {1, 23, 56, 98}, // input size (NHWC) + {128, 3, 3, 98}, // filter size (KRSC) + {4, 0, 5, 0}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d(problem_size_list)); +} + +#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED diff --git a/test/unit/core/float8.cu b/test/unit/core/float8.cu index b685838ad1..79805031d3 100644 --- a/test/unit/core/float8.cu +++ b/test/unit/core/float8.cu @@ -35,44 +35,55 @@ #include "../common/cutlass_unit_test.h" #include "cutlass/numeric_types.h" +#include ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(float_e4m3_t, host_conversion) { + using FP8 = cutlass::float_e4m3_t; + using Base = typename FP8::Base; + const float max_abs_normal_val = 448.0f; // 0.1111.110 + const float min_abs_subnormal_val = 0.001953125f; // 0.0000.001 + for (int i = -8; i < 8; ++i) { float f = static_cast(i); - cutlass::float_e4m3_t x = static_cast(i); - cutlass::float_e4m3_t y = static_cast(f); + FP8 x = static_cast(i); + FP8 y = static_cast(f); EXPECT_TRUE(static_cast(x) == i); EXPECT_TRUE(static_cast(y) == f); } // Try out default-ctor (zero initialization of primitive proxy type) - EXPECT_TRUE(cutlass::float_e4m3_t() == 0.0_fe4m3); + EXPECT_TRUE(FP8() == 0.0_fe4m3); // Try out user-defined literals - EXPECT_TRUE(cutlass::float_e4m3_t(7) == 7_fe4m3); + EXPECT_TRUE(FP8(7) == 7_fe4m3); EXPECT_TRUE(7 == static_cast(7_fe4m3)); } TEST(float_e5m2_t, host_conversion) { + using FP8 = cutlass::float_e5m2_t; + using Base = typename FP8::Base; + const float max_abs_normal_val = 57344.0f; // 0.11110.11 + const float min_abs_subnormal_val = 0.0000152588f; // 0.00000.01 + for (int i = -8; i < 8; ++i) { float f = static_cast(i); - cutlass::float_e5m2_t x = static_cast(i); - cutlass::float_e5m2_t y = static_cast(f); + FP8 x = static_cast(i); + FP8 y = static_cast(f); EXPECT_TRUE(static_cast(x) == i); EXPECT_TRUE(static_cast(y) == f); } // Try out default-ctor (zero initialization of primitive proxy type) - EXPECT_TRUE(cutlass::float_e5m2_t() == 0.0_fe5m2); + EXPECT_TRUE(FP8() == 0.0_fe5m2); // Try out user-defined literals - EXPECT_TRUE(cutlass::float_e5m2_t(7) == 7_fe5m2); + EXPECT_TRUE(FP8(7) == 7_fe5m2); EXPECT_TRUE(7 == static_cast(7_fe5m2)); } diff --git a/test/unit/cute/core/CMakeLists.txt b/test/unit/cute/core/CMakeLists.txt index 0a2006dce5..30c05d245a 100644 --- a/test/unit/cute/core/CMakeLists.txt +++ b/test/unit/cute/core/CMakeLists.txt @@ -28,7 +28,7 @@ cutlass_test_unit_add_executable( cutlass_test_unit_cute_core - + WITHOUT_CUDA array_subbyte.cpp bitfield.cpp coalesce.cpp @@ -36,11 +36,15 @@ cutlass_test_unit_add_executable( compare.cpp complement.cpp composition.cpp + core_unit.cpp inverse_left.cpp inverse_right.cpp logical_divide.cpp logical_product.cpp mixedbits.cpp + nullspace.cpp + pointer.cpp + reverse.cpp transform.cpp tuple.cpp ) diff --git a/test/unit/cute/core/array_subbyte.cpp b/test/unit/cute/core/array_subbyte.cpp index 0667e8e36a..37b6c94014 100644 --- a/test/unit/cute/core/array_subbyte.cpp +++ b/test/unit/cute/core/array_subbyte.cpp @@ -36,10 +36,27 @@ #include #include +#include TEST(CuTe_core, ArraySubbyte) { using namespace cute; + { + array_subbyte array0; + array_subbyte array1; + fill(array0, int4_t(0)); + fill(array1, int4_t(1)); + + for (int i = 0; i < array1.size(); ++i) { + array0[i+5] = array1[i]; + } + + EXPECT_EQ(int4_t(array0.back()), int4_t(1)); + + for (int i = 0; i < array1.size(); ++i) { + EXPECT_EQ(int4_t(array0[i]), int4_t(i / 5)); + } + } { array_subbyte a; @@ -112,3 +129,115 @@ TEST(CuTe_core, ArraySubbyte) //std::cout << std::endl; } } + +TEST(CuTe_core, Subbyte_iterator) +{ + using namespace cute; + + { + array_subbyte a; + auto tensor = make_tensor(subbyte_iterator(a.raw_data()), make_shape(15)); + + fill(a, uint8_t(13)); + for (int i = 0; i < int(a.size()); ++i) { + EXPECT_EQ(uint8_t(tensor(i)), 13); + tensor(i) = uint8_t(i); + EXPECT_EQ(a[i], uint8_t(tensor(i))); + } + + } + + { + array_subbyte a; + auto tensor = make_tensor(subbyte_iterator(a.raw_data()), make_shape(15)); + + fill(a, int4_t(-5)); + for (int i = 0; i < int(a.size()); ++i) { + EXPECT_EQ(int4_t(tensor(i)), int4_t(-5)); + tensor(i) = int4_t(i); + EXPECT_EQ(int4_t(a[i]), int4_t(tensor(i))); + } + + } + + { + array_subbyte a; + auto tensor = make_tensor(subbyte_iterator(a.raw_data()), make_shape(15)); + + fill(a, uint2_t(-5)); + for (int i = 0; i < int(a.size()); ++i) { + EXPECT_EQ(uint2_t(tensor(i)), uint2_t(-5)); + tensor(i) = uint2_t(i); + EXPECT_EQ(uint2_t(a[i]), uint2_t(tensor(i))); + } + + } + + { + array_subbyte a; + auto tensor = make_tensor(subbyte_iterator(a.raw_data()), make_shape(15)); + + fill(a, bool(1)); + for (int i = 0; i < int(a.size()); ++i) { + EXPECT_EQ(bool(tensor(i)), bool(1)); + tensor(i) = bool(i % 2); + EXPECT_EQ(a[i], bool(tensor(i))); + } + } +} + +TEST(CuTe_core, Const_subbyte_iterator) +{ + using namespace cute; + + { + array_subbyte a; + auto tensor = make_tensor(subbyte_iterator(a.raw_data()), make_shape(15)); + + fill(a, uint8_t(13)); + for (int i = 0; i < int(a.size()); ++i) { + EXPECT_EQ(uint8_t(tensor(i)), 13); + a[i] = uint8_t(i); + EXPECT_EQ(a[i], uint8_t(tensor(i))); + } + + } + + { + array_subbyte a; + auto tensor = make_tensor(subbyte_iterator(a.raw_data()), make_shape(15)); + + fill(a, int4_t(-5)); + for (int i = 0; i < int(a.size()); ++i) { + EXPECT_EQ(int4_t(tensor(i)), int4_t(-5)); + a[i] = int4_t(i); + EXPECT_EQ(int4_t(a[i]), int4_t(tensor(i))); + } + + } + + { + array_subbyte a; + auto tensor = make_tensor(subbyte_iterator(a.raw_data()), make_shape(15)); + + fill(a, uint2_t(-5)); + for (int i = 0; i < int(a.size()); ++i) { + EXPECT_EQ(uint2_t(tensor(i)), uint2_t(-5)); + a[i] = uint2_t(i); + EXPECT_EQ(uint2_t(a[i]), uint2_t(tensor(i))); + } + + } + + { + array_subbyte a; + auto tensor = make_tensor(subbyte_iterator(a.raw_data()), make_shape(15)); + + fill(a, bool(1)); + for (int i = 0; i < int(a.size()); ++i) { + EXPECT_EQ(bool(tensor(i)), bool(1)); + a[i] = bool(i % 2); + EXPECT_EQ(a[i], bool(tensor(i))); + } + } +} diff --git a/test/unit/cute/core/bitfield.cpp b/test/unit/cute/core/bitfield.cpp index 94b139e385..4899e47a56 100644 --- a/test/unit/cute/core/bitfield.cpp +++ b/test/unit/cute/core/bitfield.cpp @@ -46,12 +46,13 @@ using namespace cute; TEST(CuTe_core, Bitfield) { for_each(make_int_range<1,65>{}, [&](auto NumBits) { - for_each(make_int_range<0,129>{}, [&](auto BitStart) { - - using BF = bit_field; + constexpr auto num_bits = cute::remove_cvref_t::value; + for_each(make_int_range<0, 129>{}, [&](auto BitStart) { + constexpr auto bit_start = cute::remove_cvref_t::value; + using BF = bit_field::value>; #if 0 - printf("bit_field<%d,%d>:\n", decltype(BitStart)::value, decltype(NumBits)::value); + printf("bit_field<%d,%d>:\n", bit_start, num_bits); printf(" value_type_bits : %d\n", BF::value_type_bits); printf(" storage_type_bits: %d\n", BF::storage_type_bits); printf(" N : %d\n", BF::N); @@ -64,7 +65,7 @@ TEST(CuTe_core, Bitfield) #endif // Test - uint64_t v = decltype(NumBits)::value == 64 ? uint64_t(-1) : ((uint64_t(1) << NumBits) - 1); + uint64_t v = num_bits == 64 ? uint64_t(-1) : ((uint64_t(1) << NumBits) - 1); BF bf{}; bf = v; @@ -74,7 +75,7 @@ TEST(CuTe_core, Bitfield) for_each(make_int_range<0,129>{}, [&](auto BitStart) { - using BF = bit_field; + using BF = bit_field::value, 32, float>; BF bf{}; bf = 3.14f; diff --git a/test/unit/cute/core/core_unit.cpp b/test/unit/cute/core/core_unit.cpp new file mode 100644 index 0000000000..7cf7587b4c --- /dev/null +++ b/test/unit/cute/core/core_unit.cpp @@ -0,0 +1,40 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/** \file + \brief Unit tests for CuTe core +*/ + +#include + +int main(int argc, char* arg[]) { + ::testing::InitGoogleTest(&argc, arg); + return RUN_ALL_TESTS(); +} diff --git a/test/unit/cute/core/inverse_right.cpp b/test/unit/cute/core/inverse_right.cpp index 4bb5870bff..4d1501476e 100644 --- a/test/unit/cute/core/inverse_right.cpp +++ b/test/unit/cute/core/inverse_right.cpp @@ -38,6 +38,16 @@ using namespace cute; +template +void +test_postconditions(Layout const& layout, InvLayout const& inv_layout) +{ + for (int i = 0; i < size(inv_layout); ++i) { + //printf("%3d: %3d %3d\n", i, int(inv_layout(i)), int(layout(inv_layout(i)))); + EXPECT_EQ(layout(inv_layout(i)), i); + } +} + template void test_right_inverse(Layout const& layout) @@ -47,10 +57,7 @@ test_right_inverse(Layout const& layout) CUTLASS_TRACE_HOST(layout << " ^ -1\n" << " => \n" << inv_layout); CUTLASS_TRACE_HOST("Composition: " << coalesce(composition(layout, inv_layout)) << std::endl); - for (int i = 0; i < size(inv_layout); ++i) { - //printf("%3d: %3d %3d\n", i, int(inv_layout(i)), int(layout(inv_layout(i)))); - EXPECT_EQ(layout(inv_layout(i)), i); - } + test_postconditions(layout, inv_layout); } TEST(CuTe_core, Inverse_right) diff --git a/test/unit/cute/core/mixedbits.cpp b/test/unit/cute/core/mixedbits.cpp index 55027ebd24..4fd7965022 100644 --- a/test/unit/cute/core/mixedbits.cpp +++ b/test/unit/cute/core/mixedbits.cpp @@ -55,11 +55,11 @@ TEST(CuTe_core, MixedBits) { auto m0 = make_mixed_bits(S0, d0, F0); auto m1 = make_mixed_bits(S1, d1, F1); //print(m0); print(" & "); print(m1); print(" = "); print(m0 & m1); print("\n"); - EXPECT_EQ(to_integral(m0 & m1), to_integral(m0) & to_integral(m1)); + EXPECT_EQ(uint32_t(m0 & m1), uint32_t(m0) & uint32_t(m1)); //print(m0); print(" | "); print(m1); print(" = "); print(m0 | m1); print("\n"); - EXPECT_EQ(to_integral(m0 | m1), to_integral(m0) | to_integral(m1)); + EXPECT_EQ(uint32_t(m0 | m1), uint32_t(m0) | uint32_t(m1)); //print(m0); print(" ^ "); print(m1); print(" = "); print(m0 ^ m1); print("\n"); - EXPECT_EQ(to_integral(m0 ^ m1), to_integral(m0) ^ to_integral(m1)); + EXPECT_EQ(uint32_t(m0 ^ m1), uint32_t(m0) ^ uint32_t(m1)); } } } diff --git a/test/unit/cute/core/nullspace.cpp b/test/unit/cute/core/nullspace.cpp new file mode 100644 index 0000000000..a240780fca --- /dev/null +++ b/test/unit/cute/core/nullspace.cpp @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include + +#include + +using namespace cute; + +template +void +test_postconditions(Layout const& layout, KerLayout const& ker_layout) +{ + EXPECT_EQ(size(ker_layout), size(layout) / size(filter(layout))); + + for (int i = 0; i < size(ker_layout); ++i) { + //printf("%3d: %3d %3d\n", i, int(ker_layout(i)), int(layout(ker_layout(i)))); + EXPECT_EQ(layout(ker_layout(i)), 0); + } +} + +template +void +test_nullspace(Layout const& layout) +{ + auto ker_layout = nullspace(layout); + + CUTLASS_TRACE_HOST("ker(" << layout << ")\n" << " => \n" << ker_layout); + CUTLASS_TRACE_HOST("Composition: " << coalesce(composition(layout, ker_layout)) << std::endl); + + test_postconditions(layout, ker_layout); +} + +TEST(CuTe_core, Layout_nullspace) +{ + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("NULLSPACE" ); + CUTLASS_TRACE_HOST("-------------------------------"); + + { + auto layout = Layout,Stride<_0,_0,_0>>{}; + + test_nullspace(layout); + } + + { + auto layout = Layout,Stride<_0,_0,_0>>{}; + + test_nullspace(layout); + } + + { + auto layout = Layout,Stride<_1,_0,_2>>{}; + + test_nullspace(layout); + } + + { + auto layout = Layout,Stride<_3,_1,_0>>{}; + + test_nullspace(layout); + } +} diff --git a/test/unit/cute/core/pointer.cpp b/test/unit/cute/core/pointer.cpp new file mode 100644 index 0000000000..26ccb8723f --- /dev/null +++ b/test/unit/cute/core/pointer.cpp @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include + +TEST(CuTe_core, Pointer) +{ + using namespace cute; + + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("CuTe pointer wrappers"); + CUTLASS_TRACE_HOST("-------------------------------"); + + // Test T* overloads (T can be nonconst or const) + { + using T = float; + using expected_type = cute::gmem_ptr; + T* p = nullptr; + + // explicit template argument + auto gmem_p0 = cute::make_gmem_ptr(p); + static_assert(cute::is_same_v); + + // deduced template argument + auto gmem_p1 = cute::make_gmem_ptr(p); + static_assert(cute::is_same_v); + } + { + using T = float const; + using expected_type = cute::gmem_ptr; + T* p = nullptr; + + // explicit template argument + auto gmem_p0 = cute::make_gmem_ptr(p); + static_assert(cute::is_same_v); + + // deduced template argument + auto gmem_p1 = cute::make_gmem_ptr(p); + static_assert(cute::is_same_v); + } + + // Test void* and void const* overloads + // (these require an explicit template argument) + { + using T = float; + using expected_type = cute::gmem_ptr; + void* p = nullptr; + + auto gmem_p0 = cute::make_gmem_ptr(p); + static_assert(cute::is_same_v); + } + { + using T = float const; + using expected_type = cute::gmem_ptr; + void const* p = nullptr; + + auto gmem_p0 = cute::make_gmem_ptr(p); + static_assert(cute::is_same_v); + } + + // Test nullptr_t overload. + { + using T = float; + using expected_type = cute::gmem_ptr; + + auto gmem_p0 = cute::make_gmem_ptr(nullptr); + static_assert(cute::is_same_v); + } + { + using T = float const; + using expected_type = cute::gmem_ptr; + + auto gmem_p0 = cute::make_gmem_ptr(nullptr); + static_assert(cute::is_same_v); + } +} diff --git a/test/unit/cute/core/reverse.cpp b/test/unit/cute/core/reverse.cpp new file mode 100644 index 0000000000..b7ffd273c1 --- /dev/null +++ b/test/unit/cute/core/reverse.cpp @@ -0,0 +1,137 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" +#include "cutlass/trace.h" + +#include "cute/algorithm/tuple_algorithms.hpp" +#include "cute/container/array.hpp" +#include "cute/container/tuple.hpp" + +TEST(CuTe_core, Reverse_Tuple) +{ + using cute::get; + + { + const auto t = cute::make_tuple(); + [[maybe_unused]] auto t_r = cute::reverse(t); + static_assert(cute::tuple_size_v == 0); + } + + { + const auto t = cute::make_tuple(123); + [[maybe_unused]] auto t_r = cute::reverse(t); + static_assert(cute::tuple_size_v == 1); + EXPECT_EQ(get<0>(t_r), 123); + } + + { + const auto t = cute::make_tuple(123, 456); + [[maybe_unused]] auto t_r = cute::reverse(t); + static_assert(cute::tuple_size_v == 2); + EXPECT_EQ(get<0>(t_r), 456); + EXPECT_EQ(get<1>(t_r), 123); + } + + { + const auto t = cute::make_tuple(1, 2, 3, 4, 5); + auto t_r = cute::reverse(t); + static_assert(cute::tuple_size_v == 5); + + EXPECT_EQ(get<0>(t_r), 5); + EXPECT_EQ(get<1>(t_r), 4); + EXPECT_EQ(get<2>(t_r), 3); + EXPECT_EQ(get<3>(t_r), 2); + EXPECT_EQ(get<4>(t_r), 1); + } + + { + const auto t = cute::make_tuple(cute::Int<1>{}, cute::Int<2>{}, 3); + auto t_r = cute::reverse(t); + static_assert(cute::tuple_size_v == 3); + static_assert(cute::is_same_v(t_r))>, int>); + static_assert(cute::is_same_v(t_r))>, cute::Int<2>>); + static_assert(cute::is_same_v(t_r))>, cute::Int<1>>); + + EXPECT_EQ(get<0>(t_r), 3); + EXPECT_EQ(get<1>(t_r), cute::Int<2>{}); + EXPECT_EQ(get<2>(t_r), cute::Int<1>{}); + } +} + +TEST(CuTe_core, Reverse_Array) +{ + using cute::get; + + { + const auto t = cute::array{}; + [[maybe_unused]] auto t_r = cute::reverse(t); + static_assert(cute::tuple_size_v == 0); + + using reverse_type = cute::array; + static_assert(cute::is_same_v); + } + + { + const auto t = cute::array{123}; + [[maybe_unused]] auto t_r = cute::reverse(t); + static_assert(cute::tuple_size_v == 1); + EXPECT_EQ(get<0>(t_r), 123); + + using reverse_type = cute::array; + static_assert(cute::is_same_v); + } + + { + const auto t = cute::array{123, 456}; + [[maybe_unused]] auto t_r = cute::reverse(t); + static_assert(cute::tuple_size_v == 2); + EXPECT_EQ(get<0>(t_r), 456); + EXPECT_EQ(get<1>(t_r), 123); + + using reverse_type = cute::array; + static_assert(cute::is_same_v); + } + + { + const auto t = cute::array{1.125f, 2.25f, 3.5f, 4.625f, 5.75f}; + auto t_r = cute::reverse(t); + static_assert(cute::tuple_size_v == 5); + EXPECT_EQ(get<0>(t_r), 5.75f); + EXPECT_EQ(get<1>(t_r), 4.625f); + EXPECT_EQ(get<2>(t_r), 3.5f); + EXPECT_EQ(get<3>(t_r), 2.25f); + EXPECT_EQ(get<4>(t_r), 1.125f); + + using reverse_type = cute::array; + static_assert(cute::is_same_v); + } +} diff --git a/test/unit/cute/hopper/tma_load.cu b/test/unit/cute/hopper/tma_load.cu index ddb95c3c7b..2b1eb94f79 100644 --- a/test/unit/cute/hopper/tma_load.cu +++ b/test/unit/cute/hopper/tma_load.cu @@ -119,7 +119,7 @@ tma_test_device_cute(T const* g_in, T* g_out, for (int stage = 0; stage < size<1>(tAgA); ++stage) { // Set the bytes transferred in this TMA transaction (may involve multiple issues) - constexpr int kTmaTransactionBytes = size(sA) * sizeof(T); + constexpr int kTmaTransactionBytes = size(sA) * sizeof_bits_v / 8; if (threadIdx.x == 0) { @@ -140,6 +140,10 @@ tma_test_device_cute(T const* g_in, T* g_out, // Write out trivially smem -> gmem // + //if (thread0()) { + // print_tensor(sA); + //} + for (int i = threadIdx.x; i < size(sA); i += blockDim.x) { tBgB(i,stage) = sA(i); } @@ -154,14 +158,13 @@ test_tma_load(GMEM_Layout const& gmem_layout, CTA_Tile const& cta_tile) { thrust::host_vector h_in(cosize(gmem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i % 13); } thrust::device_vector d_in = h_in; thrust::device_vector d_out(h_in.size(), T(-1)); Tensor gA = make_tensor(d_in.data().get(), gmem_layout); auto tma = make_tma_copy(SM90_TMA_LOAD{}, gA, smem_layout, cta_tile, Int<1>{}); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); - //print("TMA Instr size: "); print(decltype(tma)::NumValSrc); print("\n"); + //print(tma); int smem_size = int(sizeof(SharedStorage)); tma_test_device_cute<<<1, 128, smem_size>>>( @@ -187,6 +190,26 @@ test_tma_load(GMEM_Layout const& gmem_layout, return test_tma_load(gmem_layout, smem_layout, product_each(shape(smem_layout))); } +TEST(SM90_CuTe_Hopper, Tma_Load_1D) +{ + Layout smem_layout = Layout<_256, _1>{}; + { + Layout gmem_layout = smem_layout; + test_tma_load(gmem_layout, smem_layout); + test_tma_load(gmem_layout, smem_layout); + test_tma_load< float>(gmem_layout, smem_layout); + test_tma_load(gmem_layout, smem_layout); + } + + { + Layout gmem_layout = make_layout(128, GenColMajor{}); + test_tma_load(gmem_layout, smem_layout); + test_tma_load(gmem_layout, smem_layout); + test_tma_load< float>(gmem_layout, smem_layout); + test_tma_load(gmem_layout, smem_layout); + } +} + TEST(SM90_CuTe_Hopper, Tma_Load_32x32_Col) { Layout smem_layout = Layout, Stride<_1,_32>>{}; @@ -343,45 +366,20 @@ TEST(SM90_CuTe_Hopper, Tma_Load_Swizzle_Tiles) test_tma_load_swizzle_tile_k(); } - -TEST(SM90_CuTe_Hopper, Tma_Load_Metamode) +// Tensor by-mode +TEST(SM90_CuTe_Hopper, Tma_Load_Tensor) { + // 3-mode TMA { - auto smem_layout = Layout, Stride<_1,_32>>{}; - { - Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenColMajor{}); - test_tma_load(gmem_layout, smem_layout); - } - { - Layout gmem_layout = make_layout(make_shape(make_shape(8,32), 32), GenColMajor{}); - test_tma_load(gmem_layout, smem_layout); - } - { - Layout gmem_layout = make_layout(make_shape(make_shape(64,32), 32), GenColMajor{}); - test_tma_load(gmem_layout, smem_layout); - } - } - - { - auto smem_layout = Layout, Stride<_32,_1>>{}; - { - Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenRowMajor{}); - test_tma_load(gmem_layout, smem_layout); - } - { - Layout gmem_layout = make_layout(make_shape(make_shape(8,32), 32), GenRowMajor{}); - test_tma_load(gmem_layout, smem_layout); - } - { - Layout gmem_layout = make_layout(make_shape(make_shape(64,32), 32), GenRowMajor{}); - test_tma_load(gmem_layout, smem_layout); - } + Layout gmem_layout = make_layout(make_shape(128, 64, 5)); + auto cta_tile = Shape<_64, _32>{}; // GMEM Tiling: + // Take 64-elem from m + // Take 32-elem from k + auto smem_layout = make_layout(Shape<_64,_32>{}); + test_tma_load(gmem_layout, smem_layout, cta_tile); } -} -TEST(SM90_CuTe_Hopper, Tma_Load_Tensor) -{ - // Tensor by-mode + // 4-mode TMA { Layout gmem_layout = make_layout(make_shape(make_shape(80,40),make_shape(32,12))); auto cta_tile = Shape,Shape<_32,_2>>{}; // GMEM Tiling: @@ -391,18 +389,20 @@ TEST(SM90_CuTe_Hopper, Tma_Load_Tensor) test_tma_load(gmem_layout, smem_layout, cta_tile); } - // Tensor Metamode -- Tiler selects flat elements from a multimode + // 5-mode TMA { - Layout gmem_layout = make_layout(make_shape(make_shape(32,40),make_shape(make_shape(8,8),12))); - auto cta_tile = Shape<_128, Shape<_32,_2>>{}; // GMEM Tiling: - // Take 128-elem from m: m0 must divide 128, - // m-last may be predicated + Layout gmem_layout = make_layout(make_shape(make_shape(32,32,32),make_shape(32,12))); + auto cta_tile = Shape,Shape<_16,_2>>{}; // GMEM Tiling: + // Take 4-elem from m0, 4-elem from m1, 5-elem from m2 // Take 32-elem from k0, 2-elem from k1 - auto smem_layout = make_layout(Shape<_128,_64>{}); + auto smem_layout = make_layout(Shape<_128,_32>{}); test_tma_load(gmem_layout, smem_layout, cta_tile); } +} - // Tensor Multimode -- TMA with more than 5 modes in GMEM (packs residual modes into last TMA mode) +// Tensor Multimode -- TMA with more than 5 modes in GMEM (packs residual modes into last TMA mode) +TEST(SM90_CuTe_Hopper, Tma_Load_Tensor_Multimode) +{ { Layout gmem_layout = make_layout(make_shape(make_shape(32,3,2,2),make_shape(32,4,2))); auto cta_tile = Shape, Shape<_32,_2>>{}; // GMEM Tiling: @@ -412,6 +412,23 @@ TEST(SM90_CuTe_Hopper, Tma_Load_Tensor) test_tma_load(gmem_layout, smem_layout, cta_tile); } + { + Layout gmem_layout = make_layout(make_shape(make_shape(64,3,2,2),make_shape(32,4,2))); + auto cta_tile = Shape, Shape<_32,_2>>{}; // GMEM Tiling: + // Take 32-elem from m0, 3-elem from m1 + // Take 32-elem from k0, 2-elem from k1 + auto smem_layout = make_layout(Shape<_96,_64>{}); + test_tma_load(gmem_layout, smem_layout, cta_tile); + } + + { + Layout gmem_layout = make_layout(make_shape(make_shape(64,3,2,3,2),make_shape(32,4,2,2))); + auto cta_tile = Shape, Shape<_16,_2>>{}; // GMEM Tiling: + // Take 32-elem from m0 + // Take 16-elem from k0, 2-elem from k1 + auto smem_layout = make_layout(Shape<_32,_32>{}); + test_tma_load(gmem_layout, smem_layout, cta_tile); + } } #endif diff --git a/test/unit/cute/hopper/tma_store.cu b/test/unit/cute/hopper/tma_store.cu index 4d96070a67..7a8c49ac83 100644 --- a/test/unit/cute/hopper/tma_store.cu +++ b/test/unit/cute/hopper/tma_store.cu @@ -145,14 +145,13 @@ test_tma_store(GMEM_Layout const& gmem_layout, CTA_Tile const& cta_tile) { thrust::host_vector h_in(cosize(gmem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i); } + for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i % 13); } thrust::device_vector d_in = h_in; thrust::device_vector d_out(h_in.size(), T(-1)); Tensor gA = make_tensor(d_out.data().get(), gmem_layout); auto tma = make_tma_copy(SM90_TMA_STORE{}, gA, smem_layout, cta_tile, Int<1>{}); - //print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); - //print("TMA Instr size: "); print(decltype(tma)::NumValSrc); print("\n"); + //print(tma); int smem_size = int(sizeof(SharedStorage)); tma_test_device_cute<<<1, 128, smem_size>>>( @@ -178,6 +177,26 @@ test_tma_store(GMEM_Layout const& gmem_layout, return test_tma_store(gmem_layout, smem_layout, product_each(shape(smem_layout))); } +TEST(SM90_CuTe_Hopper, Tma_Load_1D) +{ + Layout smem_layout = Layout<_256, _1>{}; + { + Layout gmem_layout = smem_layout; + test_tma_store(gmem_layout, smem_layout); + test_tma_store(gmem_layout, smem_layout); + test_tma_store< float>(gmem_layout, smem_layout); + test_tma_store(gmem_layout, smem_layout); + } + + { + Layout gmem_layout = make_layout(128, GenColMajor{}); + test_tma_store(gmem_layout, smem_layout); + test_tma_store(gmem_layout, smem_layout); + test_tma_store< float>(gmem_layout, smem_layout); + test_tma_store(gmem_layout, smem_layout); + } +} + TEST(SM90_CuTe_Hopper, Tma_Store_32x32_Col) { Layout smem_layout = Layout, Stride<_1,_32>>{}; @@ -335,44 +354,20 @@ TEST(SM90_CuTe_Hopper, Tma_Store_Swizzle_Tiles) } -TEST(SM90_CuTe_Hopper, Tma_Store_Metamode) +// Tensor by-mode +TEST(SM90_CuTe_Hopper, Tma_Store_Tensor) { + // 3-mode TMA { - auto smem_layout = Layout, Stride<_1,_32>>{}; - { - Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenColMajor{}); - test_tma_store(gmem_layout, smem_layout); - } - { - Layout gmem_layout = make_layout(make_shape(make_shape(8,32), 32), GenColMajor{}); - test_tma_store(gmem_layout, smem_layout); - } - { - Layout gmem_layout = make_layout(make_shape(make_shape(64,32), 32), GenColMajor{}); - test_tma_store(gmem_layout, smem_layout); - } - } - - { - auto smem_layout = Layout, Stride<_32,_1>>{}; - { - Layout gmem_layout = make_layout(make_shape(make_shape(8,4), 32), GenRowMajor{}); - test_tma_store(gmem_layout, smem_layout); - } - { - Layout gmem_layout = make_layout(make_shape(make_shape(8,32), 32), GenRowMajor{}); - test_tma_store(gmem_layout, smem_layout); - } - { - Layout gmem_layout = make_layout(make_shape(make_shape(64,32), 32), GenRowMajor{}); - test_tma_store(gmem_layout, smem_layout); - } + Layout gmem_layout = make_layout(make_shape(128, 64, 5)); + auto cta_tile = Shape<_64, _32>{}; // GMEM Tiling: + // Take 64-elem from m + // Take 32-elem from k + auto smem_layout = make_layout(Shape<_64,_32>{}); + test_tma_store(gmem_layout, smem_layout, cta_tile); } -} -TEST(SM90_CuTe_Hopper, Tma_Store_Tensor) -{ - // Tensor by-mode + // 4-mode TMA { Layout gmem_layout = make_layout(make_shape(make_shape(80,40),make_shape(32,12))); auto cta_tile = Shape,Shape<_32,_2>>{}; // GMEM Tiling: @@ -382,18 +377,20 @@ TEST(SM90_CuTe_Hopper, Tma_Store_Tensor) test_tma_store(gmem_layout, smem_layout, cta_tile); } - // Tensor Metamode -- Tiler selects flat elements from a multimode + // 5-mode TMA { - Layout gmem_layout = make_layout(make_shape(make_shape(32,40),make_shape(make_shape(8,8),12))); - auto cta_tile = Shape<_128, Shape<_32,_2>>{}; // GMEM Tiling: - // Take 128-elem from m: m0 must divide 128, - // m-last may be predicated + Layout gmem_layout = make_layout(make_shape(make_shape(32,32,32),make_shape(32,12))); + auto cta_tile = Shape,Shape<_16,_2>>{}; // GMEM Tiling: + // Take 4-elem from m0, 4-elem from m1, 5-elem from m2 // Take 32-elem from k0, 2-elem from k1 - auto smem_layout = make_layout(Shape<_128,_64>{}); + auto smem_layout = make_layout(Shape<_128,_32>{}); test_tma_store(gmem_layout, smem_layout, cta_tile); } +} - // Tensor Multimode -- TMA with more than 5 modes in GMEM (packs residual modes into last TMA mode) +// Tensor Multimode -- TMA with more than 5 modes in GMEM (packs residual modes into last TMA mode) +TEST(SM90_CuTe_Hopper, Tma_Store_Tensor_Multimode) +{ { Layout gmem_layout = make_layout(make_shape(make_shape(32,3,2,2),make_shape(32,4,2))); auto cta_tile = Shape, Shape<_32,_2>>{}; // GMEM Tiling: @@ -403,6 +400,23 @@ TEST(SM90_CuTe_Hopper, Tma_Store_Tensor) test_tma_store(gmem_layout, smem_layout, cta_tile); } + { + Layout gmem_layout = make_layout(make_shape(make_shape(64,3,2,2),make_shape(32,4,2))); + auto cta_tile = Shape, Shape<_32,_2>>{}; // GMEM Tiling: + // Take 32-elem from m0, 3-elem from m1 + // Take 32-elem from k0, 2-elem from k1 + auto smem_layout = make_layout(Shape<_96,_64>{}); + test_tma_store(gmem_layout, smem_layout, cta_tile); + } + + { + Layout gmem_layout = make_layout(make_shape(make_shape(64,3,2,3,2),make_shape(32,4,2,2))); + auto cta_tile = Shape, Shape<_16,_2>>{}; // GMEM Tiling: + // Take 32-elem from m0 + // Take 16-elem from k0, 2-elem from k1 + auto smem_layout = make_layout(Shape<_32,_32>{}); + test_tma_store(gmem_layout, smem_layout, cta_tile); + } } #endif diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index 717dbd5bd0..56deefdb0f 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -88,7 +88,7 @@ cutlass_test_unit_add_executable( simt_cgemm_nt_sm50.cu simt_cgemm_tn_sm50.cu simt_cgemm_tt_sm50.cu - + simt_qgemm_nn_sm50.cu simt_qgemm_nt_sm50.cu simt_qgemm_tn_sm50.cu @@ -98,30 +98,30 @@ cutlass_test_unit_add_executable( simt_dgemm_nt_sm50.cu simt_dgemm_tn_sm50.cu simt_dgemm_tt_sm50.cu - + simt_hgemm_nn_sm50.cu simt_hgemm_nt_sm50.cu simt_hgemm_tn_sm50.cu simt_hgemm_tt_sm50.cu - + simt_igemm_nn_sm50.cu simt_igemm_nt_sm50.cu simt_igemm_tn_sm50.cu simt_igemm_tt_sm50.cu - + simt_int8_igemm_sm61_sliced_k.cu simt_int8_igemm_sm61.cu - + simt_sgemm_nn_sm50.cu simt_sgemm_nt_sm50.cu simt_sgemm_tn_sm50.cu simt_sgemm_tt_sm50.cu - + simt_zgemm_nn_sm50.cu simt_zgemm_nt_sm50.cu simt_zgemm_tn_sm50.cu simt_zgemm_tt_sm50.cu - + gemm_splitk_simt_sm50.cu ) @@ -193,7 +193,7 @@ cutlass_test_unit_add_executable( gemm_s4n_s4t_s4n_tensor_op_s32_sm75.cu gemm_s4t_s4n_s32t_tensor_op_s32_sm75.cu gemm_s4t_s4n_s32n_tensor_op_s32_sm75.cu - gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu + gemm_s4t_s4n_s4n_tensor_op_s32_sm75.cu gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu @@ -253,6 +253,17 @@ cutlass_test_unit_add_executable( sm90_gemm_s8_s8_s8_tensor_op_s32.cu sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu sm90_gemm_f32_f32_f32_tensor_op_f32.cu + sm90_gemm_f8_f8_f32_tensor_op_fp32.cu + sm90_gemm_f8_f8_bf16_tensor_op_fp32.cu + sm90_gemm_f8_f8_f8_tensor_op_fp32.cu +) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_tensorop_sm90_stream_k + + sm90_gemm_stream_k_scheduler.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_cooperative_stream_k.cu + sm90_gemm_f8_f8_f32_tensor_op_f32_cooperative_stream_k.cu ) # Alignment tests @@ -278,9 +289,16 @@ cutlass_test_unit_add_executable( sm90_gemm_s8_s8_s8_tensor_op_s32_tensor_broadcast.cu sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_load.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_aux_load.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_row_broadcast.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_row_broadcast.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_reduce.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_reduce.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_dag.cu + sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_dag.cu ) - cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90 @@ -291,6 +309,7 @@ cutlass_test_unit_add_executable( sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong.cu sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu + sm90_gemm_f8_f8_f32_tensor_op_f32_cluster_warpspecialized_cooperative.cu ) cutlass_test_unit_add_executable( @@ -311,7 +330,7 @@ cutlass_test_unit_add_executable( gemm_tf32t_tf32n_f32t_tensor_op_f32_sm80.cu gemm_tf32n_tf32t_f32t_tensor_op_f32_sm80.cu gemm_tf32n_tf32n_f32t_tensor_op_f32_sm80.cu - gemm_tf32t_tf32t_f32t_tensor_op_f32_sm80.cu + gemm_tf32t_tf32t_f32t_tensor_op_f32_sm80.cu gemm_universal_cf32n_cf32n_cf32n_tensor_op_f32_sm80.cu gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu gemm_cf32t_cf32n_cf32t_tensor_op_tf32_f32_sm80.cu @@ -331,7 +350,7 @@ cutlass_test_unit_add_executable( gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu gemm_f64t_f64n_f64t_tensor_op_f64_sm80.cu - gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu + gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu gemm_universal_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm80.cu @@ -363,7 +382,7 @@ cutlass_test_unit_add_executable( gemm_s8t_s8n_f16t_tensor_op_s32_sm80.cu gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu gemm_s4t_s4n_s32t_tensor_op_s32_sm80.cu - gemm_s4t_s4n_s4n_tensor_op_s32_sm80.cu + gemm_s4t_s4n_s4n_tensor_op_s32_sm80.cu gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu @@ -436,8 +455,8 @@ cutlass_test_unit_add_executable( BATCH_SOURCES ON BATCH_SIZE 4 - gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu - gemm_planar_complex_f16_f16_f32_tensor_op_sm75.cu + gemm_planar_complex_f16_f16_f32_tensor_op_sm70.cu + gemm_planar_complex_f16_f16_f32_tensor_op_sm75.cu gemm_planar_complex_f16_f16_f32_tensor_op_sm80.cu ) cutlass_test_unit_add_executable( @@ -512,7 +531,7 @@ add_dependencies( cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_gemm_with_fused_epilogue_tensorop - + gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu @@ -541,8 +560,8 @@ cutlass_test_unit_add_executable( ## SYRK # Syrk SM80 f64 tests - syrk_f64n_f64t_tensor_op_f64_sm80.cu - syrk_f64t_f64n_tensor_op_f64_sm80.cu + syrk_f64n_f64t_tensor_op_f64_sm80.cu + syrk_f64t_f64n_tensor_op_f64_sm80.cu # Syrk SM80 f32 tests syrk_tf32n_f32t_tensor_op_f32_sm80.cu @@ -567,7 +586,7 @@ cutlass_test_unit_add_executable( # Syrk SM90 complex f64 tests syrk_cf64_cf64_tensor_op_f64_sm90.cu - ## HERK + ## HERK # Herk SM80 complex f64 tests herk_cf64h_cf64n_tensor_op_f64_sm80.cu @@ -689,7 +708,7 @@ cutlass_test_unit_add_executable( hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_sm80.cu hemm_cf64h_cf64n_cf64n_tensor_op_rs_f64_sm80.cu hemm_cf64h_cf64n_cf64n_tensor_op_ls_f64_gaussian_sm80.cu - + # Hemm SM80 complex f32 tests hemm_cf32h_cf32n_tensor_op_f32_ls_sm80.cu hemm_cf32h_cf32n_tensor_op_f32_rs_sm80.cu diff --git a/test/unit/gemm/device/default_gemm_configuration.hpp b/test/unit/gemm/device/default_gemm_configuration.hpp index adfb9eda43..96d7894681 100644 --- a/test/unit/gemm/device/default_gemm_configuration.hpp +++ b/test/unit/gemm/device/default_gemm_configuration.hpp @@ -1046,7 +1046,7 @@ struct DefaultGemmConfigurationToCutlass3Types< // A (M,K) K-Major using SmemLayoutAtomA = decltype( - composition(SwizzleXor<2,0,2>{}, + composition(Swizzle<2,0,4>{}, Layout, Stride<_1, _4>>{})); // M, K using SmemCopyAtomA = Copy_Atom; @@ -1059,7 +1059,7 @@ struct DefaultGemmConfigurationToCutlass3Types< // B (N,K) K-Major using SmemLayoutAtomB = decltype( - composition(SwizzleXor<2,0,2>{}, + composition(Swizzle<2,0,4>{}, Layout, Stride<_1, _4>>{})); // N, K using SmemCopyAtomB = Copy_Atom; @@ -1124,7 +1124,7 @@ struct DefaultGemmConfigurationToCutlass3Types< // A (M,K) M-Major using SmemLayoutAtomA = decltype( - composition(SwizzleXor<2,2,0>{}, + composition(Swizzle<2,2,2>{}, Layout, Stride< _1,_16>>{})); // M, K using SmemCopyAtomA = Copy_Atom; @@ -1137,7 +1137,7 @@ struct DefaultGemmConfigurationToCutlass3Types< // B (N,K) K-Major using SmemLayoutAtomB = decltype( - composition(SwizzleXor<2,0,2>{}, + composition(Swizzle<2,0,4>{}, Layout, Stride<_1, _4>>{}));// N, K using SmemCopyAtomB = Copy_Atom; @@ -1188,7 +1188,7 @@ struct DefaultGemmConfigurationToCutlass3Types< // A (M,K) M-Major using SmemLayoutAtomA = decltype( - composition(SwizzleXor<2,2,0>{}, + composition(Swizzle<2,2,2>{}, Layout, Stride< _1,_16>>{})); // M, K using SmemCopyAtomA = Copy_Atom; @@ -1201,7 +1201,7 @@ struct DefaultGemmConfigurationToCutlass3Types< // B (N,K) N-Major using SmemLayoutAtomB = decltype( - composition(SwizzleXor<2,2,0>{}, + composition(Swizzle<2,2,2>{}, Layout, Stride< _1,_16>>{})); // N, K using SmemCopyAtomB = Copy_Atom; @@ -1252,7 +1252,7 @@ struct DefaultGemmConfigurationToCutlass3Types< // A (M,K) K-Major using SmemLayoutAtomA = decltype( - composition(SwizzleXor<2,0,2>{}, + composition(Swizzle<2,0,4>{}, Layout, Stride<_1, _4>>{})); // M, K using SmemCopyAtomA = Copy_Atom; @@ -1265,7 +1265,7 @@ struct DefaultGemmConfigurationToCutlass3Types< // B (N,K) N-Major using SmemLayoutAtomB = decltype( - composition(SwizzleXor<2,2,0>{}, + composition(Swizzle<2,2,2>{}, Layout, Stride< _1,_16>>{})); // N, K using SmemCopyAtomB = Copy_Atom; diff --git a/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu index 45c1d802d9..3f9b31d4e4 100644 --- a/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm75.cu @@ -49,7 +49,6 @@ #include "testbed.h" #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM75_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x256x512_64x64x512) { diff --git a/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu index 67bcb8573d..c3bb8b1aa4 100644 --- a/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu +++ b/test/unit/gemm/device/gemm_b1t_b1n_s32n_tensor_op_s32_sm80.cu @@ -47,9 +47,329 @@ #include "testbed.h" #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) - //////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x256x1024_64x64x1024) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 1024>, + cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpAndPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + +TEST(SM80_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x128x1024_64x64x1024) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 1024>, + cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3,128, 128, + false, cutlass::arch::OpAndPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + +TEST(SM80_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x128x1024_64x64x1024) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 1024>, + cutlass::gemm::GemmShape<64, 64, 1024>, + cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpAndPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + +TEST(SM80_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x64x1024_64x64x1024) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 1024>, + cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpAndPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + +TEST(SM80_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x256x1024_64x64x1024) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 1024>, + cutlass::gemm::GemmShape<64, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpAndPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + +TEST(SM80_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x128x1024_32x64x1024) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 1024>, + cutlass::gemm::GemmShape<32, 64, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpAndPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + +TEST(SM80_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x64x1024_64x32x1024) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 1024>, + cutlass::gemm::GemmShape<64, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpAndPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + +TEST(SM80_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x64x1024_32x32x1024) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 1024>, + cutlass::gemm::GemmShape<32, 32, 1024>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, + false, cutlass::arch::OpAndPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + +TEST(SM80_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x256x512_64x64x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 512>, + cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpAndPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + +TEST(SM80_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x128x512_64x64x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 512>, + cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpAndPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + +TEST(SM80_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x128x512_64x64x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 512>, + cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpAndPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + +TEST(SM80_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 256x64x512_64x64x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 512>, + cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpAndPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + +TEST(SM80_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x256x512_64x64x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 512>, + cutlass::gemm::GemmShape<64, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, 128, 128, + false, cutlass::arch::OpAndPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + +TEST(SM80_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x128x512_32x64x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 512>, + cutlass::gemm::GemmShape<32, 64, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, + false, cutlass::arch::OpAndPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + +TEST(SM80_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 128x64x512_64x32x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 512>, + cutlass::gemm::GemmShape<64, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, 128, 128, + false, cutlass::arch::OpAndPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + +TEST(SM80_Device_Gemm_b1t_b1n_s32n_tensor_op_s32, 64x64x512_32x32x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + cutlass::uint1b_t, cutlass::layout::RowMajor, cutlass::uint1b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 512>, + cutlass::gemm::GemmShape<32, 32, 512>, cutlass::gemm::GemmShape<16, 8, 256>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6, 128, 128, + false, cutlass::arch::OpAndPopc>; + + EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); +} + //////////////////////////////////////////////////////////////////////////////// TEST(SM80_Device_Gemm_XOR_b1t_b1n_s32n_tensor_op_s32, 128x256x1024_64x64x1024) { diff --git a/test/unit/gemm/device/gemm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu index 6c8ab54ecb..b804d94f31 100644 --- a/test/unit/gemm/device/gemm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_b1t_b1n_s32n_wmma_tensor_op_s32_sm75.cu @@ -49,7 +49,6 @@ #include "cutlass/util/reference/host/gemm.h" #include "testbed.h" - ///////////////////////////////////////////////////////////////////////////////////////////////// ////// WMMA Instruction Shape = 8x8x128, DataType/Instruction = b1 ^ b1 + s32 => s32 ///////// ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -240,4 +239,5 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32n_wmma_tensor_op_s32, 64x64x512_32x32x512_8x8x1 EXPECT_TRUE(test::gemm::device::TestAllGemmBasic()); } + #endif //CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED diff --git a/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu index 445fa885b1..bfe281fb31 100644 --- a/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm75.cu @@ -49,7 +49,6 @@ #include "testbed.h" #if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 128x256x512_64x64x512) { @@ -228,5 +227,4 @@ TEST(SM75_Device_Gemm_b1t_b1n_s32t_tensor_op_s32, 64x64x512_32x32x512) { } ///////////////////////////////////////////////////////////////////////////////////////////////// - #endif diff --git a/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu index c81914802c..41ce86f111 100644 --- a/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu +++ b/test/unit/gemm/device/gemm_b1t_b1n_s32t_tensor_op_s32_sm80.cu @@ -48,7 +48,6 @@ #include "testbed.h" #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) - //////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu b/test/unit/gemm/device/gemm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu index 755661f2ba..505405e8b1 100644 --- a/test/unit/gemm/device/gemm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu +++ b/test/unit/gemm/device/gemm_b1t_b1n_s32t_wmma_tensor_op_s32_sm75.cu @@ -49,7 +49,6 @@ #include "cutlass/util/reference/host/gemm.h" #include "testbed.h" - ///////////////////////////////////////////////////////////////////////////////////////////////// ////// WMMA Instruction Shape = 8x8x128, DataType/Instruction = b1 ^ b1 + s32 => s32 ///////// ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu b/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu index cc9430350e..e84720476b 100644 --- a/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu +++ b/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm90.cu @@ -51,7 +51,6 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian, 32x32x16_16x16x16) { diff --git a/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu b/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu index e2931b0203..c30e23087f 100644 --- a/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm90.cu @@ -51,7 +51,6 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM90_Device_Gemm_cf64n_cf64t_cf64t_tensor_op_f64, 32x32x16_16x16x16) { diff --git a/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu b/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu index eb011e4c53..c4deb60bc9 100644 --- a/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu +++ b/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian_sm90.cu @@ -51,7 +51,6 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64_gaussian, 32x32x8_16x16x8) { diff --git a/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu b/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu index c0333e7c6a..233e58d1f4 100644 --- a/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/gemm_cf64t_cf64n_cf64t_tensor_op_f64_sm90.cu @@ -51,7 +51,6 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 32x32x8_16x16x8) { @@ -298,7 +297,6 @@ TEST(SM90_Device_Gemm_cf64t_cf64n_cf64t_tensor_op_f64, 128x64x16_32x32x16) { } ///////////////////////////////////////////////////////////////////////////////////////////////// - #endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f16n_direct_store_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16n_f16n_direct_store_tensor_op_f32_sm80.cu index e98764e5f0..5ebd2e4be2 100644 --- a/test/unit/gemm/device/gemm_f16n_f16n_f16n_direct_store_tensor_op_f32_sm80.cu +++ b/test/unit/gemm/device/gemm_f16n_f16n_f16n_direct_store_tensor_op_f32_sm80.cu @@ -79,7 +79,7 @@ TEST(SM80_Device_GemmUniversal_DirectStore_f16n_f16t_f32n_tensor_op_f32, 128x128 cutlass::gemm::GemmShape<16, 8, 16>, cutlass::epilogue::thread::LinearCombination< ElementOutput, - 4, // This is the vector size of the epilogue. + 4, // This is the vector size of the epilogue. ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu index fc7bf70246..6872a4168d 100644 --- a/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu +++ b/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu @@ -282,7 +282,7 @@ TEST(SM80_Device_Sparse_Gemm_Row_Broadcast_f16n_f16n_f16t_tensor_op_f32, 64x64x1 ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator>, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; - + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm(true)); } diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu b/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu index b69a3043be..71121d141f 100644 --- a/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu +++ b/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_broadcast_sm80.cu @@ -133,7 +133,6 @@ struct TestbedUtils { view.data(), view.capacity()); } else { - // TODO: Implement the rest EXPECT_TRUE(false) << "Not implemented"; return false; } diff --git a/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu b/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu index 62cb15dd77..6d704cc03f 100644 --- a/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu @@ -47,7 +47,6 @@ #include "testbed.h" #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM90_Device_Gemm_f64n_f64t_f64t_tensor_op_f64, 32x32x16_16x16x16_16x8x4) { diff --git a/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu b/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu index 881d81c8b9..211d3bfddd 100644 --- a/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/gemm_f64t_f64n_f64t_tensor_op_f64_sm90.cu @@ -47,7 +47,6 @@ #include "testbed.h" #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM90_Device_Gemm_f64t_f64n_f64t_tensor_op_f64, 32x32x16_16x16x16_16x8x4) { diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp index 5f19032ef1..30b4264949 100644 --- a/test/unit/gemm/device/gemm_testbed_3x.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -56,6 +56,7 @@ #include "cutlass/layout/matrix.h" #include "cutlass/matrix_coord.h" #include "cutlass/gemm/gemm.h" +#include "cutlass/epilogue/fusion/operations.hpp" #include "cute/int_tuple.hpp" @@ -67,6 +68,76 @@ namespace device { namespace detail{ +// Helper classes that take default data type when +// the Gemm::EpilogueOutputOp does not have ElementCompute +// and ElementScalar. +// (e.g. when Sm90TreeVisitor is used as FusionCallbacks) +template +struct ElementComputeType { + using Type = Default; +}; + +template +struct ElementComputeType> { + using Type = typename Gemm::EpilogueOutputOp::ElementCompute; +}; + +template +struct ElementScalarType { + using Type = Default; +}; + +template +struct ElementScalarType> { + using Type = typename Gemm::EpilogueOutputOp::ElementScalar; +}; + + + +// The number of splits to test. +// +// This class makes it harder to confuse the order of arguments +// of the various run(...) functions in this file. The constructor +// is explicit, so one can't just type 42 (or false, which the +// compiler unhelpfully turns into 0); one has to type Splits(42). +// Splits() picks the default number of splits, 1. +// +// The conversion-to-int operator (operator int()) MUST be explicit! +// Conversion to int MUST require static_cast. +// Otherwise, that defeats a key purpose of this class, +// which is to catch common errors of confusing the order +// of function arguments. +class Splits { +public: + Splits() = default; + + template && + !std::is_same_v)) > + explicit Splits(IntegralNotBool splits) : splits_(splits) {} + explicit operator int() const { return splits_; } +private: + int splits_ = 1; +}; + +// The number of iterations to test. +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. +// Iterations() picks the default number of iterations, 20. +class Iterations { +public: + Iterations() = default; + + template && + !std::is_same_v)) > + explicit Iterations(IntegralNotBool iterations) : iterations_(iterations) {} + explicit operator int() const { return iterations_; } +private: + int iterations_ = 20; +}; + template < typename Gemm, template class ActivationFunctor_ = cutlass::epilogue::thread::Identity @@ -83,15 +154,18 @@ struct TestbedImpl { using ElementD = typename Gemm::GemmKernel::ElementD; using StrideD = typename Gemm::GemmKernel::StrideD; using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; - using ElementCompute = typename Gemm::GemmKernel::CollectiveEpilogue::ElementCompute; - using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - using ThreadEpilogueOp = typename Gemm::GemmKernel::CollectiveEpilogue::ThreadEpilogueOp; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + /// For custom EVTs + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; using ActivationFunctor = ActivationFunctor_; static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static constexpr uint32_t mma_promotion_interval = 4; + // Looks at Cute Stride to check Row / Column Major template static constexpr bool is_row_or_col_major(){ @@ -112,10 +186,10 @@ struct TestbedImpl { "ERROR : D Layout is neither Row / Column Major)"); // Deduce Cutlass Layouts (RowMajor & ColumnMajor) - using LayoutTagA = decltype(cutlass::gemm::detail::stride_to_layout_tag_A()); - using LayoutTagB = decltype(cutlass::gemm::detail::stride_to_layout_tag_B()); - using LayoutTagC = decltype(cutlass::gemm::detail::stride_to_layout_tag_A()); - using LayoutTagD = decltype(cutlass::gemm::detail::stride_to_layout_tag_A()); + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + using LayoutTagC = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagA_t; /// Initialization StrideA stride_a; @@ -241,10 +315,10 @@ struct TestbedImpl { auto K = cute::size<2>(problem_shape_MNKL); auto L = cute::size<3>(problem_shape_MNKL); - stride_a = make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); - stride_b = make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); - stride_c = make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); - stride_d = make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode auto a_coord = cutlass::make_Coord(M * L, K); @@ -347,24 +421,33 @@ struct TestbedImpl { auto D = cute::make_tensor(reference_D.host_data(), cute::make_layout(cute::make_shape(M, N, L), stride_d)); auto Bias = cute::make_tensor(static_cast(nullptr), - cute::make_layout(cute::make_shape(M, 1))); - auto T = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, cute::_1{}))); + auto Aux = cute::make_tensor(static_cast(nullptr), cute::make_layout(cute::make_shape(M, N, L), stride_d)); + auto Valpha = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, cute::_1{}))); + auto Vbeta = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, cute::_1{}))); + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; cutlass::reference::host::GettEpilogueParams< + ElementScalar, ElementScalar, ElementAccumulator, ElementCompute, decltype(C), decltype(D), decltype(Bias), - decltype(T), + decltype(Aux), + decltype(Valpha), + decltype(Vbeta), ActivationFunctor > epilogue_params{ alpha, beta, - C, D, Bias, T + C, D, Bias, Aux + , Valpha, Vbeta }; cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); @@ -446,7 +529,8 @@ struct TestbedImpl { ElementScalar alpha = ElementScalar(1), ElementScalar beta = ElementScalar(0), bool profiling = false, - int iterations = 20) + detail::Iterations iterations = Iterations{}, + detail::Splits splits = Splits{}) { // Fail test if insufficient CUDA device if (!sufficient()) { @@ -472,6 +556,11 @@ struct TestbedImpl { hw_info.sm_count = this->sm_count; } + typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; + if constexpr (std::is_same_v) { + scheduler_args = { static_cast(splits) }; + } + // DefaultEpilogue arguments = typename Gemm::Arguments{ cutlass::gemm::GemmUniversalMode::kGemm, @@ -484,7 +573,8 @@ struct TestbedImpl { {alpha, beta}, tensor_C.device_data(), stride_c, tensor_D.device_data(), stride_d }, - hw_info + hw_info, + scheduler_args }; Gemm gemm_op; @@ -505,7 +595,7 @@ struct TestbedImpl { // if (profiling) { - return profile(problem_size, iterations, gemm_op, arguments, workspace); + return profile(problem_size, static_cast(iterations), gemm_op, arguments, workspace); } else { cudaError_t result; @@ -550,9 +640,10 @@ struct Testbed3x { using Kernel = typename Gemm::GemmKernel; using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; - using ElementAccumulator = typename Kernel::ElementAccumulator; - using ElementCompute = typename Epilogue::ElementCompute; - using ElementScalar = typename Epilogue::ElementScalar; + using ElementAccumulator = typename TestBedImpl::ElementAccumulator; + using ElementCompute = typename TestBedImpl::ElementCompute; + using ElementScalar = typename TestBedImpl::ElementScalar; + using LayoutTagA = typename TestBedImpl::LayoutTagA; using LayoutTagB = typename TestBedImpl::LayoutTagB; using LayoutTagC = typename TestBedImpl::LayoutTagC; @@ -594,25 +685,34 @@ struct Testbed3x { typename TestBedImpl::ProblemShapeType problem_size, ElementScalar alpha = ElementScalar(1), ElementScalar beta = ElementScalar(0), + detail::Splits splits = detail::Splits{}, bool profiling = false, - int iterations = 20) + detail::Iterations iterations = detail::Iterations{}) { return impl_.run( - problem_size, alpha, beta, profiling, iterations + problem_size, alpha, beta, profiling, iterations, splits ); } }; ///////////////////////////////////////////////////////////////////////////////////////////////// -// Testbed for GEMMs with epilogues including a bias operation and an elementwise function +// Testbed for GEMMs with fused epilogues using the fusion::FusionOperation API +// Does not support testing of custom EVTs template -struct Testbed3xBiasElementwise { +struct Testbed3xFusionOperation { using TestBedImpl = typename detail::TestbedImpl; using Kernel = typename Gemm::GemmKernel; using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + using LayoutTagA = typename TestBedImpl::LayoutTagA; + using LayoutTagB = typename TestBedImpl::LayoutTagB; + using LayoutTagC = typename TestBedImpl::LayoutTagC; + using LayoutTagD = typename TestBedImpl::LayoutTagD; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + using ElementA = typename Kernel::ElementA; using StrideA = typename Kernel::StrideA; using ElementB = typename Kernel::ElementB; @@ -621,34 +721,76 @@ struct Testbed3xBiasElementwise { using StrideC = typename Kernel::StrideC; using ElementD = typename Kernel::ElementD; using StrideD = typename Kernel::StrideD; - - using ElementAccumulator = typename Kernel::ElementAccumulator; - using ElementCompute = typename Epilogue::ElementCompute; - using ProblemShapeType = typename Kernel::ProblemShape; - using ElementBias = typename Epilogue::ElementBias; - using ElementT = typename Epilogue::ElementT; - using ElementScalar = typename Epilogue::ElementScalar; - using ActivationFunctor = typename Epilogue::ActivationFunctor; - using BinaryOp = typename Epilogue::BinaryOp; - - static constexpr bool IsBiasEnabled = Epilogue::iskThreadEpilogueOpWithBias; - static constexpr bool StoreT = Epilogue::StoreT; - - using LayoutTagA = typename TestBedImpl::LayoutTagA; - using LayoutTagB = typename TestBedImpl::LayoutTagB; - using LayoutTagC = typename TestBedImpl::LayoutTagC; - using LayoutTagD = typename TestBedImpl::LayoutTagD; - using LayoutTagVector = cutlass::layout::PackedVectorLayout; - - cutlass::HostTensor bias; - cutlass::HostTensor< ElementT, LayoutTagD> tensor_T; - cutlass::HostTensor< ElementT, LayoutTagD> reference_T; + using ProblemShapeType = typename Kernel::ProblemShape; + using ElementAccumulator = typename Kernel::ElementAccumulator; + + // + // FusionOperation derived types/queries + // + using FusionOp = typename Gemm::EpilogueOutputOp; + static_assert(cute::is_base_of_v); + + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ElementCompute = typename FusionOp::ElementCompute; + using ElementScalar = typename FusionOp::ElementScalar; + using ElementBias = non_void_t; + using ElementAux = non_void_t; + using ElementAmax = non_void_t; + using LayoutTagAux = non_void_t; + using ActivationFunctor = non_void_t, + cutlass::epilogue::thread::Identity>; + + static constexpr bool IsBiasEnabled = FusionOp::IsPerRowBiasSupported; + static constexpr bool IsPerRowScaleEnabled = FusionOp::IsPerRowScaleSupported; + static constexpr bool IsScaleFactorEnabled = FusionOp::IsScaleFactorSupported; + static constexpr bool IsAuxEnabled = FusionOp::IsAuxOutSupported; + static constexpr bool IsAbsMaxEnabled = FusionOp::IsAbsMaxSupported; + + // Legacy support for deprecated bias-elementwise collective, will be removed next release + using EpiloguePolicy = typename Epilogue::DispatchPolicy; + static constexpr bool IsLegacy = + cute::is_same_v< + EpiloguePolicy, + cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise< + EpiloguePolicy::StagesC, EpiloguePolicy::StagesD, EpiloguePolicy::FragmentSize> + >; + + // Inputs + cutlass::HostTensor alpha; + cutlass::HostTensor beta; + cutlass::HostTensor scale_A; + cutlass::HostTensor scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor bias; + // Outputs + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + cutlass::HostTensor tensor_Aux; + cutlass::gemm::TagToStrideC_t< LayoutTagAux > stride_Aux; + // References + cutlass::HostTensor reference_Aux; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; // Detail Implementation TestBedImpl impl_; // Whether to use relative equality checks - bool check_relative_equality; + bool check_relative_equality = false; + // Are scalars copied to device memory before kernel launch + bool use_device_scalars = false; + // If per-row scale is enabled and this is true, beta is passed as a host scalar instead of device vector + bool disable_vector_beta = false; + // Random distribution with which to initialize the A/B/C/D/Aux scaling factors + cutlass::Distribution::Kind init_scale = cutlass::Distribution::Uniform; + // Random distribution with which to initialize the bias vector + cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; // Factors used for calculating relative equality. These default // values are borrowed from those used by default in the CUTLASS @@ -659,24 +801,22 @@ struct Testbed3xBiasElementwise { // // Methods // - Testbed3xBiasElementwise( - bool check_relative_equality_, + Testbed3xFusionOperation( + bool check_relative_equality_ = false, + bool use_device_scalars_ = false, + bool disable_vector_beta_ = false, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, uint64_t seed_ = TestBedImpl::kDefaultSeed - ) : - impl_(init_A_, init_B_, init_C_, seed_), check_relative_equality(check_relative_equality_) { } + ) : impl_(init_A_, init_B_, init_C_, seed_), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_), + init_scale(init_scale_), init_bias(init_bias_) { } - Testbed3xBiasElementwise( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, - uint64_t seed_ = TestBedImpl::kDefaultSeed - ) : - impl_(init_A_, init_B_, init_C_, seed_), check_relative_equality(false) { } - - Testbed3xBiasElementwise( + Testbed3xFusionOperation( typename LayoutTagA::Stride stride_factor_A_, typename LayoutTagB::Stride stride_factor_B_, typename LayoutTagC::Stride stride_factor_C_, @@ -685,41 +825,94 @@ struct Testbed3xBiasElementwise { cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, uint64_t seed_ = TestBedImpl::kDefaultSeed - ) : - impl_(stride_factor_A_, - stride_factor_B_, - stride_factor_C_, - stride_factor_D_, - init_A_, - init_B_, - init_C_, - seed_), - check_relative_equality(false) { } + ) : impl_(stride_factor_A_, + stride_factor_B_, + stride_factor_C_, + stride_factor_D_, + init_A_, + init_B_, + init_C_, + seed_) { } /// Initializes data structures - void initialize(ProblemShapeType problem_size) { - // - // Allocate the GEMM workspace for A/B/C/D/T tensor - // + void initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + auto scalar_coord = cutlass::make_Coord(1); + auto col_vector_coord = cutlass::make_Coord(M); + + // Allocate the GEMM workspace for A/B/C/D tensor impl_.initialize(problem_size); - if constexpr (StoreT) { - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto [M, N, K, L] = problem_shape_MNKL; - auto c_coord = cutlass::make_Coord(M * L, N); - tensor_T.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.stride_factor_D)); - reference_T.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.stride_factor_D), false); - tensor_T.sync_device(); + if constexpr (IsPerRowScaleEnabled) { + alpha.resize(col_vector_coord); + EXPECT_TRUE(impl_.initialize_tensor(alpha.host_view(), init_scale, impl_.seed + 2023)); + if (disable_vector_beta) { + beta.resize(scalar_coord, false); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + else { + beta.resize(col_vector_coord); + EXPECT_TRUE(impl_.initialize_tensor(beta.host_view(), init_scale, impl_.seed + 2024)); + } + } + else { + alpha.resize(scalar_coord, use_device_scalars); + beta.resize(scalar_coord, use_device_scalars); + cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + alpha.sync_device(); + beta.sync_device(); + + if constexpr (IsScaleFactorEnabled) { + scale_A.resize(scalar_coord, use_device_scalars); + scale_B.resize(scalar_coord, use_device_scalars); + scale_C.resize(scalar_coord, use_device_scalars); + scale_D.resize(scalar_coord, use_device_scalars); + EXPECT_TRUE(impl_.initialize_tensor(scale_A.host_view(), init_scale, impl_.seed + 2023)); + EXPECT_TRUE(impl_.initialize_tensor(scale_B.host_view(), init_scale, impl_.seed + 2024)); + EXPECT_TRUE(impl_.initialize_tensor(scale_C.host_view(), init_scale, impl_.seed + 2025)); + EXPECT_TRUE(impl_.initialize_tensor(scale_D.host_view(), init_scale, impl_.seed + 2026)); + scale_A.sync_device(); + scale_B.sync_device(); + scale_C.sync_device(); + scale_D.sync_device(); } - } - void initialize_bias(ProblemShapeType problem_size) { - auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto M = cute::get<0>(problem_shape_MNKL); - bias.resize(cutlass::Coord<1>(M)); + if constexpr (IsBiasEnabled) { + bias.resize(col_vector_coord); + EXPECT_TRUE(impl_.initialize_tensor(bias.host_view(), init_bias, impl_.seed + 2023)); + bias.sync_device(); + } + + if constexpr (IsAbsMaxEnabled) { + abs_max_D.resize(scalar_coord); + abs_max_D.sync_device(); + reference_abs_max_D.resize(scalar_coord); + } + + if constexpr (IsAuxEnabled) { + auto aux_coord = cutlass::make_Coord(M * L, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + tensor_Aux.resize(aux_coord, aux_layout); + reference_Aux.resize(aux_coord, aux_layout, false); + tensor_Aux.sync_device(); + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, L)); + + if constexpr (IsScaleFactorEnabled) { + scale_Aux.resize(scalar_coord, use_device_scalars); + EXPECT_TRUE(impl_.initialize_tensor(scale_Aux.host_view(), init_scale, impl_.seed + 2027)); + scale_Aux.sync_device(); + } + + if constexpr (IsAbsMaxEnabled) { + abs_max_Aux.resize(scalar_coord); + abs_max_Aux.sync_device(); + reference_abs_max_Aux.resize(scalar_coord); + } + } - EXPECT_TRUE(impl_.initialize_tensor(bias.host_view(), cutlass::Distribution::Uniform, impl_.seed + 2023)); - bias.sync_device(); } template < @@ -740,39 +933,38 @@ struct Testbed3xBiasElementwise { } /// Compares computed reference with device reference and outputs to a file if incorrect - bool compare_reference( - cute::Shape problem_shape_MNKL, - ElementScalar alpha, - ElementScalar beta) { + bool compare_reference(cute::Shape problem_shape_MNKL) { auto [M, N, K, L] = problem_shape_MNKL; auto coord_0 = cutlass::make_Coord(0); - impl_.tensor_D.sync_host(); - tensor_T.sync_host(); EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_A.host_view()), 0); EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_B.host_view()), 0); EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_C.host_view()), 0); + impl_.tensor_D.sync_host(); if (impl_.tensor_D.size() > 1) { EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.tensor_D.host_view()), 0); - } - - if (impl_.reference_D.size() > 1) { EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.reference_D.host_view()), 0); } + bool passed = equality_check(impl_.reference_D.host_view(), impl_.tensor_D.host_view()); - if constexpr (StoreT) { - EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T.host_view()), 0); - EXPECT_GT(cutlass::reference::host::TensorNorm(reference_T.host_view()), 0); + if constexpr (IsAbsMaxEnabled) { + abs_max_D.sync_host(); + passed &= equality_check(reference_abs_max_D.host_view(), abs_max_D.host_view()); } - bool passed_D = equality_check(impl_.reference_D.host_view(), impl_.tensor_D.host_view()); - EXPECT_TRUE(passed_D); - - bool passed_T = StoreT ? equality_check(reference_T.host_view(), tensor_T.host_view()) : true; - EXPECT_TRUE(passed_T); + if constexpr (IsAuxEnabled) { + tensor_Aux.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Aux.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_Aux.host_view()), 0); + passed &= equality_check(reference_Aux.host_view(), tensor_Aux.host_view()); + if constexpr (IsAbsMaxEnabled) { + abs_max_Aux.sync_host(); + passed &= equality_check(reference_abs_max_Aux.host_view(), abs_max_Aux.host_view()); + } + } - bool passed = passed_D && passed_T; + EXPECT_TRUE(passed); if (!passed) { std::stringstream fname; fname << "error_Gemm_device_" @@ -783,35 +975,65 @@ struct Testbed3xBiasElementwise { std::ofstream file(fname.str()); file - << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L - << ", alpha: " << float(alpha) << ", beta: " << float(beta) << "\n\n"; - - if constexpr (IsBiasEnabled) { - file << "Bias = \n" << bias.host_view()<< "\n\n"; + << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L; + if constexpr (IsScaleFactorEnabled) { + file + << ", scale_a: " << scale_A.at(coord_0) + << ", scale_b: " << scale_B.at(coord_0) + << ", scale_c: " << scale_C.at(coord_0); + } + if constexpr (IsPerRowScaleEnabled) { + file << "\n\nvalpha = \n" << alpha.host_view(); + file << "\n\nvbeta = \n" << beta.host_view(); + } else { + file + << ", alpha: " << alpha.at(coord_0) << ", beta: " << beta.at(coord_0); + } + file << "\n\n"; + + if constexpr (IsAbsMaxEnabled) { + file << "scale_d: " << float(scale_D.at(coord_0)); + file << "\nReference abs_max_D :"; + file << " " << float(reference_abs_max_D.at(coord_0)); + + file << "\nComputed abs_max_D :"; + file << " " << float(abs_max_D.at(coord_0)); + file << "\n\n"; + if constexpr (IsAuxEnabled) { + file << "scale_aux: " << float(scale_Aux.at(coord_0)); + file << "\nReference abs_max_Aux :"; + file << " " << float(reference_abs_max_Aux.at(coord_0)); + + file << "\nComputed abs_max_Aux :"; + file << " " << float(abs_max_Aux.at(coord_0)); + file << "\n\n"; + } } file << "A =\n" << impl_.tensor_A.host_view() << "\nB =\n" << impl_.tensor_B.host_view() << "\nC =\n" << impl_.tensor_C.host_view(); - if constexpr (StoreT) { + + if constexpr (IsBiasEnabled) { + file << "\n\nBias = \n" << bias.host_view(); + } + + if constexpr (IsAuxEnabled) { file - << "\n\nReference_T =\n" << reference_T.host_view() - << "\n\nComputed_T =\n" << tensor_T.host_view(); + << "\n\nReference Aux =\n" << reference_Aux.host_view() + << "\n\nComputed Aux =\n" << tensor_Aux.host_view(); } file - << "\n\nReference_D =\n" << impl_.reference_D.host_view() - << "\n\nComputed_D =\n" << impl_.tensor_D.host_view(); + << "\n\nReference D =\n" << impl_.reference_D.host_view() + << "\n\nComputed D =\n" << impl_.tensor_D.host_view(); } return passed; } /// Verifies the result against a reference implementation - bool verify( - ProblemShapeType problem_size, - ElementScalar alpha, - ElementScalar beta) + bool verify(ProblemShapeType problem_size) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto M = cute::get<0>(problem_shape_MNKL); @@ -828,43 +1050,81 @@ struct Testbed3xBiasElementwise { cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c)); auto D = cute::make_tensor(impl_.reference_D.host_data(), cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d)); - auto Bias = cute::make_tensor(static_cast(IsBiasEnabled ? bias.host_data() : nullptr), - cute::make_layout(cute::make_shape(M, 1))); - auto T = cute::make_tensor(static_cast(StoreT ? reference_T.host_data() : nullptr), - cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d)); + auto Bias = cute::make_tensor(bias.host_data(), + cute::make_layout(cute::make_shape(M, cute::_1{}))); + auto Aux = cute::make_tensor(reference_Aux.host_data(), + cute::make_layout(cute::make_shape(M, N, L), stride_Aux)); + auto Valpha = cute::make_tensor(alpha.host_data(), + cute::make_layout(cute::make_shape(M, cute::_1{}))); + auto Vbeta = cute::make_tensor(beta.host_data(), + cute::make_layout(cute::make_shape(M, cute::_1{}))); + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; cutlass::reference::host::GettEpilogueParams< + ElementScalar, ElementScalar, ElementAccumulator, ElementCompute, decltype(C), decltype(D), decltype(Bias), - decltype(T), - ActivationFunctor, - BinaryOp> - epilogue_params{ - alpha, - beta, - C, - D, - Bias, - T - }; + decltype(Aux), + decltype(Valpha), + decltype(Vbeta), + ActivationFunctor> + epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha.at(coord_0); + epilogue_params.beta = beta.at(coord_0); + + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_a = scale_A.at(coord_0); + epilogue_params.scale_b = scale_B.at(coord_0); + epilogue_params.scale_c = scale_C.at(coord_0); + epilogue_params.scale_d = scale_D.at(coord_0); + } + + if constexpr (IsBiasEnabled) { + epilogue_params.Bias = Bias; + } + + if constexpr (IsAbsMaxEnabled) { + epilogue_params.abs_max_D = reference_abs_max_D.host_data(); + } + + if constexpr (IsAuxEnabled) { + epilogue_params.Aux = Aux; + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_aux = scale_Aux.at(coord_0); + } + if constexpr (IsAbsMaxEnabled) { + epilogue_params.abs_max_Aux = reference_abs_max_Aux.host_data(); + } + } + + if constexpr (IsPerRowScaleEnabled) { + epilogue_params.Valpha = Valpha; + if (not disable_vector_beta) { + epilogue_params.Vbeta = Vbeta; + } + } cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); - return compare_reference(problem_shape_MNKL, alpha, beta); + return compare_reference(problem_shape_MNKL); } /// Executes one test bool run( - ProblemShapeType problem_size, - ElementScalar alpha = ElementScalar(1), - ElementScalar beta = ElementScalar(0), - bool profiling = false, - int iterations = 20) + ProblemShapeType problem_size, + ElementScalar alpha_ = ElementScalar(1), + ElementScalar beta_ = ElementScalar(0), + detail::Splits splits = detail::Splits{}, + bool profiling = false, + detail::Iterations iterations = detail::Iterations{}) { // Fail test if insufficient CUDA device if (!impl_.sufficient()) { @@ -889,11 +1149,11 @@ struct Testbed3xBiasElementwise { /// Initializes data structures /// A/B/C/D Tensor - initialize(problem_size); + initialize(problem_size, alpha_, beta_); - /// bias - if constexpr (IsBiasEnabled){ - initialize_bias(problem_size); + typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; + if constexpr (std::is_same_v) { + scheduler_args = { static_cast(splits) }; } arguments = typename Gemm::Arguments{ @@ -904,20 +1164,67 @@ struct Testbed3xBiasElementwise { impl_.tensor_B.device_data(), impl_.stride_b }, { // Epilogue arguments - { - alpha, - beta - }, + {}, // thread impl_.tensor_C.device_data(), impl_.stride_c, impl_.tensor_D.device_data(), - impl_.stride_d, - bias.device_data(), - tensor_T.device_data() + impl_.stride_d }, // Epilogue arguments end - hw_info + hw_info, + scheduler_args }; + auto coord_0 = cutlass::make_Coord(0); + if constexpr (IsLegacy) { + arguments.epilogue.thread = { + alpha.at(coord_0), + beta.at(coord_0), + alpha.device_data(), + beta.device_data() + }; + arguments.epilogue.ptr_Bias = bias.device_data(); + arguments.epilogue.ptr_T = tensor_Aux.device_data(); + } + else { + auto &fusion_args = arguments.epilogue.thread; + + fusion_args.alpha = alpha.at(coord_0); + fusion_args.beta = beta.at(coord_0); + fusion_args.alpha_ptr = alpha.device_data(); + fusion_args.beta_ptr = beta.device_data(); // if disable_vector_beta is true this is nullptr + + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_a = scale_A.at(coord_0); + fusion_args.scale_b = scale_B.at(coord_0); + fusion_args.scale_c = scale_C.at(coord_0); + fusion_args.scale_d = scale_D.at(coord_0); + fusion_args.scale_a_ptr = scale_A.device_data(); + fusion_args.scale_b_ptr = scale_B.device_data(); + fusion_args.scale_c_ptr = scale_C.device_data(); + fusion_args.scale_d_ptr = scale_D.device_data(); + } + + if constexpr (IsBiasEnabled) { + fusion_args.bias_ptr = bias.device_data(); + } + + if constexpr (IsAbsMaxEnabled) { + fusion_args.amax_D_ptr = abs_max_D.device_data(); + } + + if constexpr (IsAuxEnabled) { + fusion_args.aux_ptr = tensor_Aux.device_data(); + fusion_args.dAux = stride_Aux; + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_aux = scale_Aux.at(coord_0); + fusion_args.scale_aux_ptr = scale_Aux.device_data(); + } + if constexpr (IsAbsMaxEnabled) { + fusion_args.amax_aux_ptr = abs_max_Aux.device_data(); + } + } + } + Gemm gemm_op; size_t workspace_size = Gemm::get_workspace_size(arguments); @@ -936,7 +1243,7 @@ struct Testbed3xBiasElementwise { // if (profiling) { - return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); + return impl_.profile(problem_size, static_cast(iterations), gemm_op, arguments, workspace); } else { cudaError_t result; @@ -953,9 +1260,9 @@ struct Testbed3xBiasElementwise { // // Verify // - bool passed = this->verify(problem_size, alpha, beta); + bool passed = this->verify(problem_size); if (!passed) { - std::cout << "Error : Failed : with alpha: " << float(alpha) << ", beta: " << float(beta) + std::cout << "Error : Failed : with alpha: " << float(alpha_) << ", beta: " << float(beta_) << "\n"; } @@ -968,10 +1275,10 @@ struct Testbed3xBiasElementwise { template < typename Gemm, - template class ActivationFunctor = cutlass::epilogue::thread::Identity + typename Testbed = Testbed3x > -bool TestAll() { - using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; +bool TestAll(double alpha = 1.0, double beta = 0.0, Testbed testbed = {}) { + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); @@ -989,28 +1296,39 @@ bool TestAll() { std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; - Testbed3x testbed; + std::vector problem_splits = {1}; + if constexpr (std::is_same_v) { + problem_splits.push_back(2); + problem_splits.push_back(3); + + // As many splits as there are maximum k tiles + problem_splits.push_back(Stages + 1); + } + bool passed = true; for (int m : problem_size_m) { for (int n : problem_size_n) { for (int k : problem_size_k) { - ProblemShapeType problem_size; - if constexpr (cute::rank(ProblemShapeType{}) == 4) { - problem_size = ProblemShapeType{m, n, k, /* l */ 1}; - } - else { - problem_size = ProblemShapeType{m, n, k}; - } - - passed = testbed.run( - problem_size, - cutlass::from_real(1), - cutlass::from_real(0) - ); - - if (!passed) { - return false; + for (int splits : problem_splits) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + detail::Splits(splits) + ); + + if (!passed) { + return false; + } } } } @@ -1021,8 +1339,8 @@ bool TestAll() { auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; passed = testbed.run( problem_size, - cutlass::from_real(1), - cutlass::from_real(0) + cutlass::from_real(alpha), + cutlass::from_real(beta) ); if (!passed) { @@ -1036,70 +1354,14 @@ bool TestAll() { ///////////////////////////////////////////////////////////////////////////////////////////////// template -bool TestAllBiasElementwise(bool check_relative_equality=false) { - using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; - using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; - - int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); - std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; - std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; +bool TestAllBiasElementwise(double alpha = 1.0, double beta = 0.0, bool check_relative_equality=false) { + Testbed3xFusionOperation testbed(check_relative_equality); - if constexpr (std::is_same_v) { - problem_size_m.push_back(768); - problem_size_n.push_back(768); - } - - constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; - constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); - - std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; - - Testbed3xBiasElementwise testbed(check_relative_equality); - bool passed = true; - - for (int m : problem_size_m) { - for (int n : problem_size_n) { - for (int k : problem_size_k) { - ProblemShapeType problem_size; - if constexpr (cute::rank(ProblemShapeType{}) == 4) { - problem_size = ProblemShapeType{m, n, k, /* l */ 1}; - } - else { - problem_size = ProblemShapeType{m, n, k}; - } - - passed = testbed.run( - problem_size, - cutlass::from_real(1), - cutlass::from_real(0) - ); - - if (!passed) { - return false; - } - } - } - } - - // if we do support batched GEMM, just run one test on it to save on test time - if constexpr (cute::rank(ProblemShapeType{}) == 4) { - auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; - passed = testbed.run( - problem_size, - cutlass::from_real(1), - cutlass::from_real(0) - ); - - if (!passed) { - return false; - } - } - - return passed; + return TestAll(alpha, beta, testbed); } ///////////////////////////////////////////////////////////////////////////////////////////////// + template bool TestGemmPerf3x(int iterations = 20) { using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; @@ -1129,7 +1391,7 @@ bool TestGemmPerf3x(int iterations = 20) { cutlass::from_real(1), cutlass::from_real(0), true, - iterations + detail::Iterations(iterations) ); if (!passed) { @@ -1148,7 +1410,7 @@ bool TestGemmPerf3x(int iterations = 20) { cutlass::from_real(1), cutlass::from_real(0), true, - iterations + detail::Iterations(iterations) ); if (!passed) { diff --git a/test/unit/gemm/device/gemm_testbed_3x_evt.hpp b/test/unit/gemm/device/gemm_testbed_3x_evt.hpp new file mode 100644 index 0000000000..c6d6da09f5 --- /dev/null +++ b/test/unit/gemm/device/gemm_testbed_3x_evt.hpp @@ -0,0 +1,1458 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Testbed and host reference for EVT unittest +*/ + + +#pragma once +#include "gemm_testbed_3x.hpp" + +namespace test { +namespace gemm { +namespace device { + +/// Host-side tapply, tapply in cute is HOST_DEVICE +template +constexpr auto +tapply(T&& t, F&& f, G&& g, cute::seq) +{ + return g(f(std::get(static_cast(t)))...); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT: Base class for EVT Node + +template < + typename Gemm_ +> +class HostEVTNodeBase { +public: + using Gemm = Gemm_; + using TestBedImpl = typename detail::TestbedImpl; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Kernel::CollectiveEpilogue; + using ElementCompute = typename TestBedImpl::ElementCompute; + using ElementScalar = typename TestBedImpl::ElementScalar; + using ElementAccumulator = typename Kernel::ElementAccumulator; + using ElementC = typename Kernel::ElementC; + using ElementD = typename Kernel::ElementD; + + using LayoutTagC = typename TestBedImpl::LayoutTagC; + using LayoutTagD = typename TestBedImpl::LayoutTagD; +private: + bool _check_relative_equality; + // Factors used for calculating relative equality. These default + // values are borrowed from those used by default in the CUTLASS + // profiler for performing relative equality checks. + float _epsilon = 0.05f; + float _nonzero_floor = 1.0f / 256.0f; + +public: + HostEVTNodeBase(){} + HostEVTNodeBase(bool check_relative_equality): + _check_relative_equality(check_relative_equality) { } + + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + if (_check_relative_equality) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, Element(_epsilon), Element(_nonzero_floor) + ); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + void* get_tensor_C_ptr() { + return nullptr; + } + + void* get_tensor_D_ptr() { + return nullptr; + } + + bool compare_reference(std::stringstream& error_ss) { + return true; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Accumulator + +template < + typename Gemm +> +class HostAccumulator: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using TestBedImpl = typename Base::TestBedImpl; + using ElementAccumulator = typename Base::ElementAccumulator; + using ElementCompute = typename Base::ElementCompute; + + struct Arguments { }; + +private: + cutlass::NumericConverter accumulator_converter; +public: + HostAccumulator(){} + template + HostAccumulator(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) + :Base(check_relative_equality) {} + + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + return accumulator_converter(acc); + } + + Arguments get_arguments() { + return Arguments{}; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Scalar Broadcast + +template < + typename Gemm, + int Value, + int BroadcastCount = 1, + template class ReductionFn = cutlass::multiplies +> +class HostScalarBroadcast : public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ElementCompute = typename Base::ElementCompute; + + struct Arguments { + ElementCompute scalar[BroadcastCount]; + ElementCompute const* scalar_ptrs[BroadcastCount]; + cute::Stride dScalar; + }; +private: + ElementCompute _scalar; +public: + HostScalarBroadcast(){} + template + HostScalarBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) + :_scalar(ElementCompute(Value)), Base(check_relative_equality) {} + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + return _scalar; + } + + bool compare_reference(std::stringstream& error_ss) { + error_ss << "Scalar: " << float(_scalar) << "\n\n"; + return true; + } + + Arguments get_arguments() { + if constexpr (BroadcastCount == 1) + return Arguments{{_scalar}, {nullptr}}; + else if constexpr (BroadcastCount == 2) + return Arguments{{_scalar, _scalar}, {nullptr, nullptr}}; + else if constexpr (BroadcastCount == 3) + return Arguments{{_scalar, _scalar, _scalar}, {nullptr, nullptr, nullptr}}; + else + return Arguments{{_scalar}, {nullptr}}; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Row Broadcast +template < + typename Gemm, + typename ElementBias_=void +> +class HostRowBroadcast: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ElementBias = std::conditional_t, + typename Base::ElementC, + ElementBias_>; + + using TestBedImpl = typename Base::TestBedImpl; + using ElementCompute = typename Base::ElementCompute; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + struct Arguments { + ElementBias const* ptr_row = nullptr; + ElementBias null_default = ElementBias(0); + cute::Stride dRow = {}; + }; +private: + cutlass::NumericConverter _bias_converter; + cutlass::HostTensor _bias; + int _N; + TestBedImpl impl_; +public: + HostRowBroadcast(){} + template + HostRowBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) + :impl_(impl), Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + _N = cute::get<1>(problem_shape_MNKL); + _bias.resize(cutlass::Coord<1>(_N)); + + EXPECT_TRUE( + impl_.initialize_tensor( + _bias.host_view(), cutlass::Distribution::Uniform, + impl_.seed + 2023 + ) + ); + _bias.sync_device(); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + auto TensorBias = cute::make_tensor(_bias.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}, _N))); + + return _bias_converter(TensorBias(1, n + n_b)); + } + + bool compare_reference(std::stringstream& error_ss) { + error_ss + << "PerColumnBias = \n" << _bias.host_view() << "\n\n"; + return true; + } + + Arguments get_arguments() { + return {_bias.device_data()}; + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Column Broadcast +template < + typename Gemm, + typename ElementBias_=void +> +class HostColBroadcast: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ElementBias = std::conditional_t, + typename Base::ElementC, + ElementBias_>; + + using TestBedImpl = typename Base::TestBedImpl; + using ElementCompute = typename Base::ElementCompute; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + struct Arguments { + ElementBias const* ptr_row = nullptr; + ElementBias null_default = ElementBias(0); + cute::Stride dRow = {}; + }; +private: + cutlass::NumericConverter _bias_converter; + cutlass::HostTensor _bias; + int _M; + TestBedImpl impl_; +public: + HostColBroadcast(){} + template + HostColBroadcast(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) + :impl_(impl), Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + _M = cute::get<0>(problem_shape_MNKL); + _bias.resize(cutlass::Coord<1>(_M)); + + EXPECT_TRUE( + impl_.initialize_tensor( + _bias.host_view(), cutlass::Distribution::Uniform, + impl_.seed + 2023 + ) + ); + _bias.sync_device(); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + auto TensorBias = cute::make_tensor(_bias.host_data(), + cute::make_layout(cute::make_shape(_M, cute::_1{}))); + + return _bias_converter(TensorBias(m + m_b, 1)); + } + + bool compare_reference(std::stringstream& error_ss) { + error_ss + << "PerRowBias = \n" << _bias.host_view() << "\n\n"; + return true; + } + + Arguments get_arguments() { + return {_bias.device_data()}; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Aux Load + +template < + typename Gemm, + bool isC=false, + typename ElementAuxLoad_=void, + typename LayoutTagAux_=void +> +class HostAuxLoad: public HostEVTNodeBase { +public: + using ElementAuxLoad = std::conditional_t, + typename HostEVTNodeBase::ElementC, + ElementAuxLoad_>; + using LayoutTagAux = std::conditional_t, + typename HostEVTNodeBase::LayoutTagC, + LayoutTagAux_>; + + using Base = HostEVTNodeBase; + using TestBedImpl = typename Base::TestBedImpl; + using ElementCompute = typename Base::ElementCompute; + + using StrideAux = cutlass::gemm::TagToStrideC_t; + struct Arguments_Aux { + ElementAuxLoad const *ptr_aux = nullptr; + ElementAuxLoad null_default = ElementAuxLoad(0); + StrideAux dAux = {}; + }; + + struct Arguments_C {}; + + using Arguments = cute::conditional_t; + +private: + cutlass::NumericConverter _aux_load_converter; + cutlass::HostTensor _tensor_aux_load; + + int _M, _N, _L; + + TestBedImpl impl_; + + StrideAux _stride_aux; +public: + HostAuxLoad(){} + template + HostAuxLoad(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) + :impl_(impl), Base(check_relative_equality){ + auto problem_shape_NMKL = cute::append<4>(problem_size, 1); + auto [_M, _N, K, _L] = problem_shape_NMKL; + auto aux_coord = cutlass::make_Coord(_M * _L, _N); + _tensor_aux_load.resize( + aux_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory( + aux_coord, typename LayoutTagAux::Stride() + ) + ); + EXPECT_TRUE( + impl_.initialize_tensor( + _tensor_aux_load.host_view(), + cutlass::Distribution::Uniform, + impl_.seed + 2023 + ) + ); + _tensor_aux_load.sync_device(); + _stride_aux = cutlass::make_cute_packed_stride(StrideAux{}, cute::make_shape(_M, _N, _L)); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + + auto TensorAuxLoad = cute::make_tensor(_tensor_aux_load.host_data(), + cute::make_layout(cute::make_shape(_M, _N, _L), _stride_aux)); + return _aux_load_converter(TensorAuxLoad(m + m_b, n + n_b, l)); + } + + bool compare_reference(std::stringstream& error_ss) { + if constexpr (!isC) { + error_ss + << "AuxLoad = \n" << _tensor_aux_load.host_view()<< "\n\n"; + } + return true; + } + + void* get_tensor_C_ptr() { + if constexpr (isC) { + return static_cast(_tensor_aux_load.device_data()); + } else { + return nullptr; + } + } + + Arguments get_arguments() { + if constexpr (isC) + return {}; + else + return {_tensor_aux_load.device_data(), ElementAuxLoad(0), _stride_aux}; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Compute + +template +T* findNonNullPtr(T* first_ptr) { + return first_ptr; +} + +template +T* findNonNullPtr(T* first_ptr, Args... args) { + if (first_ptr) { + return first_ptr; + } + return findNonNullPtr(args...); +} + +template < + typename Gemm, + template class ComputeOp_ +> +class HostCompute: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ElementCompute = typename Base::ElementCompute; + using ComputeOp = ComputeOp_; + + struct Arguments { + struct OpArgs {} op; + }; +private: + ComputeOp _op; +public: + HostCompute(){} + template + HostCompute(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): + Base(check_relative_equality) { } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, Args... frg_inputs) { + return _op(frg_inputs...); + } + + Arguments get_arguments(){ + return {}; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Unary Compute + +template < + typename Gemm, + template class ComputeOp_, + typename Child0 +> +class HostUnaryCompute: public HostEVTNodeBase { +public: + + using Base = HostEVTNodeBase; + using ElementCompute = typename Base::ElementCompute; + using ComputeOp = ComputeOp_; + + struct Arguments { + typename Child0::Arguments child_0_args; + struct OpArgs {} op; + }; +private: + ComputeOp _op; + Child0 _child_0; +public: + HostUnaryCompute(){} + template + HostUnaryCompute(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): + _child_0(problem_size, impl, check_relative_equality), + Base(check_relative_equality) { } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + ElementCompute child_0_result = _child_0.visit(m, n, l, m_b, n_b, acc); + return _op(child_0_result); + } + + Arguments get_arguments(){ + return { + _child_0.get_arguments(), + {}, + }; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Aux Store + +template < + typename Gemm, + bool isD=false, + class ElementAuxStore_=void, + typename LayoutTagAux_=void +> +class HostAuxStore: public HostEVTNodeBase { +public: + using ElementAuxStore = std::conditional_t, + typename HostEVTNodeBase::ElementD, + ElementAuxStore_>; + using LayoutTagAux = std::conditional_t, + typename HostEVTNodeBase::LayoutTagD, + LayoutTagAux_>; + + using Base = HostEVTNodeBase; + using TestBedImpl = typename Base::TestBedImpl; + using ElementCompute = typename Base::ElementCompute; + + using StrideAux = cutlass::gemm::TagToStrideC_t; + struct Arguments_Aux { + struct OpArgs { + ElementAuxStore* ptr_aux = nullptr; + StrideAux dAux = {}; + } op; + }; + + struct Arguments_D {}; + + using Arguments = cute::conditional_t; + + +private: + cutlass::NumericConverter destination_converter; + cutlass::HostTensor _tensor_aux_store; + cutlass::HostTensor _reference_aux_store; + int _M, _N, _L; + TestBedImpl impl_; + StrideAux _stride_aux; +public: + HostAuxStore(){} + template + HostAuxStore(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): + impl_(impl), + Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [_M, _N, K, _L] = problem_shape_MNKL; + auto aux_coord = cutlass::make_Coord(_M * _L, _N); + _tensor_aux_store.resize( + aux_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory( + aux_coord, typename LayoutTagAux::Stride() + ) + ); + + _reference_aux_store.resize( + aux_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory( + aux_coord, typename LayoutTagAux::Stride() + ) + ); + _tensor_aux_store.sync_device(); + _stride_aux = cutlass::make_cute_packed_stride(StrideAux{}, cute::make_shape(_M, _N, _L)); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + + auto TensorAuxStore = cute::make_tensor(static_cast(_reference_aux_store.host_data()), + cute::make_layout(cute::make_shape(_M, _N, _L), _stride_aux)); + TensorAuxStore(m + m_b, n + n_b, l) = destination_converter(child_0_result); + return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + // Verify the store node + _tensor_aux_store.sync_host(); + + bool equal = this->equality_check(_reference_aux_store.host_view(), _tensor_aux_store.host_view()); + if (!equal) { + error_ss + << "\n\nReference =\n" << _reference_aux_store.host_view() + << "\n\nComputed =\n" << _tensor_aux_store.host_view() << "\n\n"; + } + return equal; + } + + void* get_tensor_D_ptr() { + if constexpr (isD) + return static_cast(_tensor_aux_store.device_data()); + else + return nullptr; + } + + Arguments get_arguments() { + if constexpr (isD) { + return {}; + } else { + return {_tensor_aux_store.device_data(), _stride_aux}; + } + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Row Reduce + +template < + typename Gemm, + template class ReduceFn, + typename ElementReduce +> +class HostRowReduce: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using TestBedImpl = typename Base::TestBedImpl; + using ElementCompute = typename Base::ElementCompute; + using ElementOutput = typename Base::ElementD; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + struct Arguments { + struct OpArgs { + ElementReduce* ptr_row = nullptr; + ElementCompute reduce_identity = 0; + cute::Stride dRow = {}; + } op; + }; + +private: + cutlass::NumericConverter destination_converter; + cutlass::HostTensor _tensor_row_reduce; + cutlass::HostTensor _reduce_buffer; + cutlass::HostTensor _reference_row_reduce; + int _N; + TestBedImpl impl_; + ReduceFn reduce_fn; +public: + HostRowReduce(){} + template + HostRowReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): + impl_(impl), + Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + _N = cute::get<1>(problem_shape_MNKL); + _tensor_row_reduce.resize(cutlass::Coord<1>(_N)); + _reference_row_reduce.resize(cutlass::Coord<1>(_N)); + _reduce_buffer.resize(cutlass::Coord<1>(_N)); + + _tensor_row_reduce.sync_device(); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + auto TensorRowReduce = cute::make_tensor(_reduce_buffer.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}, _N))); + TensorRowReduce(1, n + n_b) = reduce_fn(TensorRowReduce(1, n + n_b), child_0_result); + return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + // Verify the store node + _tensor_row_reduce.sync_host(); + + auto TensorRowReduce = cute::make_tensor(_reference_row_reduce.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}, _N))); + + auto TensorReduceBuffer = cute::make_tensor(_reduce_buffer.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}, _N))); + + // Filling the reference tensor with the reduce buffer + for (int n = 0; n < _N; n ++) { + TensorRowReduce(1, n) = destination_converter(TensorReduceBuffer(1, n)); + } + + bool equal = this->equality_check(_reference_row_reduce.host_view(), _tensor_row_reduce.host_view()); + if (!equal) { + error_ss + << "\n\nRow Reduce Reference =\n" << _reference_row_reduce.host_view() + << "\n\nRow Reduce Computed =\n" << _tensor_row_reduce.host_view() << "\n\n"; + } + return equal; + } + + Arguments get_arguments() { + return {_tensor_row_reduce.device_data()}; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Column Reduce + +template < + typename Gemm, + template class ReduceFn, + typename ElementReduce +> +class HostColumnReduce: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using TestBedImpl = typename Base::TestBedImpl; + using ElementCompute = typename Base::ElementCompute; + using ElementOutput = typename Base::ElementD; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + struct Arguments { + struct OpArgs { + ElementReduce* ptr_col = nullptr; + ElementCompute reduce_identity = 0; + cute::Stride dRow = {}; + } op; + }; + +private: + cutlass::NumericConverter destination_converter; + cutlass::HostTensor _tensor_column_reduce; + cutlass::HostTensor _reduce_buffer; + cutlass::HostTensor _reference_column_reduce; + int _M; + TestBedImpl impl_; + ReduceFn reduce_fn; +public: + HostColumnReduce(){} + template + HostColumnReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): + impl_(impl), + Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + _M = cute::get<0>(problem_shape_MNKL); + _tensor_column_reduce.resize(cutlass::Coord<1>(_M)); + _reference_column_reduce.resize(cutlass::Coord<1>(_M)); + _reduce_buffer.resize(cutlass::Coord<1>(_M)); + + _tensor_column_reduce.sync_device(); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + auto TensorColReduce = cute::make_tensor(_reduce_buffer.host_data(), + cute::make_layout(cute::make_shape(_M, cute::_1{}))); + TensorColReduce(m + m_b, 1) = reduce_fn(TensorColReduce(m + m_b, 1), child_0_result); + return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + // Verify the store node + _tensor_column_reduce.sync_host(); + + auto TensorColReduce = cute::make_tensor(_reference_column_reduce.host_data(), + cute::make_layout(cute::make_shape(_M, cute::_1{}))); + + auto TensorReduceBuffer = cute::make_tensor(_reduce_buffer.host_data(), + cute::make_layout(cute::make_shape(_M, cute::_1{}))); + + // Filling the reference tensor with the reduce buffer + for (int m = 0; m < _M; m ++) { + TensorColReduce(m, 1) = destination_converter(TensorReduceBuffer(m, 1)); + } + + bool equal = this->equality_check(_reference_column_reduce.host_view(), _tensor_column_reduce.host_view()); + if (!equal) { + error_ss + << "\n\nColumn Reduce Reference =\n" << _reference_column_reduce.host_view() + << "\n\nColumn Reduce Computed =\n" << _tensor_column_reduce.host_view() << "\n\n"; + } + return equal; + } + + Arguments get_arguments() { + return {_tensor_column_reduce.device_data()}; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Scalar Reduce + +template < + typename Gemm, + template class ReduceFn, + typename ElementReduce +> +class HostScalarReduce: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using TestBedImpl = typename Base::TestBedImpl; + using ElementCompute = typename Base::ElementCompute; + using ElementOutput = typename Base::ElementD; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + struct Arguments { + struct OpArgs { + ElementReduce* ptr_scalar = nullptr; + ElementCompute reduce_identity = 0; + cute::Stride dScalar = {}; + } op; + }; + +private: + cutlass::NumericConverter destination_converter; + cutlass::HostTensor _tensor_scalar_reduce; + cutlass::HostTensor _reduce_buffer; + cutlass::HostTensor _reference_scalar_reduce; + ReduceFn reduce_fn; + TestBedImpl impl_; +public: + HostScalarReduce(){} + template + HostScalarReduce(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false): + impl_(impl), + Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + _tensor_scalar_reduce.resize(cutlass::Coord<1>(1)); + _reference_scalar_reduce.resize(cutlass::Coord<1>(1)); + _reduce_buffer.resize(cutlass::Coord<1>(1)); + + _tensor_scalar_reduce.sync_device(); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + auto TensorRowReduce = cute::make_tensor(_reduce_buffer.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}))); + TensorRowReduce(0) = reduce_fn(TensorRowReduce(0), child_0_result); + return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + // Verify the store node + _tensor_scalar_reduce.sync_host(); + + auto TensorRowReduce = cute::make_tensor(_reference_scalar_reduce.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}))); + + auto TensorReduceBuffer = cute::make_tensor(_reduce_buffer.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}))); + + // Filling the reference tensor with the reduce buffer + TensorRowReduce(0) = destination_converter(TensorReduceBuffer(0)); + + bool equal = this->equality_check(_reference_scalar_reduce.host_view(), _tensor_scalar_reduce.host_view()); + if (!equal) { + error_ss + << "\n\nScalar Reduce Reference =\n" << _reference_scalar_reduce.host_view() + << "\n\nScalar Reduce Computed =\n" << _tensor_scalar_reduce.host_view() << "\n\n"; + } + return equal; + } + + Arguments get_arguments() { + return {_tensor_scalar_reduce.device_data()}; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Host EVT wrapper + +/// The ArgumentPack is used to model the alignment when num ops <= 4 +template +struct ArgumentPack; + +template +struct ArgumentPack { + T arg; + ArgumentPack(T first): + arg(first) {} +}; + +template +struct ArgumentPack { + First arg; + ArgumentPack rest_args; + + ArgumentPack(First first, Rest... rest) : + arg(first), rest_args(rest...) {} +}; + + +/// Base class for Host Visitor +template +struct HostVisitorBase: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ElementCompute = typename Base::ElementCompute; + + using Arguments_struct = ArgumentPack; + using Arguments_tuple = cute::tuple; + + constexpr static int Rm1 = sizeof...(Ops); + constexpr static bool cond = Rm1 > 4; + using Arguments = cute::conditional_t; + + std::tuple ops; + + HostVisitorBase(){} + template + HostVisitorBase(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) + :Base(check_relative_equality), + ops(test::gemm::device::tapply(std::tuple{}, + [&] (auto&& op) { + using Op = cute::remove_cvref_t; + return Op(problem_size, impl, check_relative_equality); + }, + [] (auto&&... _ops) { + return std::make_tuple(_ops...); + }, + cute::make_seq{} + )){ } + + bool compare_reference(std::stringstream& error_ss) { + return cute::detail::tapply(ops, + [&](auto& op) { + return op.compare_reference(error_ss); + }, + [&] (auto&&... inputs) { + return arrayAnd(inputs...); + }, + cute::make_seq{} + ); + } + + void* get_tensor_C_ptr() { + return cute::detail::tapply(ops, + [&](auto& op) { + return op.get_tensor_C_ptr(); + }, + [&] (auto&&... inputs) { + return findNonNullPtr(inputs...); + }, + cute::make_seq{} + ); + } + + void* get_tensor_D_ptr() { + return cute::detail::tapply(ops, + [&](auto& op) { + return op.get_tensor_D_ptr(); + }, + [&] (auto&&... inputs) { + return findNonNullPtr(inputs...); + }, + cute::make_seq{} + ); + } + + Arguments get_arguments() { + return test::gemm::device::tapply(ops, + [&](auto& op) { + return op.get_arguments(); + }, + [&] (auto&&... args) { + if constexpr (Rm1 > 4) { + return cute::make_tuple(args...); + } else { + return Arguments(args...); + } + }, + cute::make_seq{} + ); + } + + bool arrayAnd(bool passed) { + return passed; + } + + template + bool arrayAnd(bool first_passed, Args... passed) { + if (first_passed) { + return arrayAnd(passed...); + } + return first_passed; + } + +}; + + +/// Tree-struct visitor +template +struct HostTreeVisitor: public HostVisitorBase { +public: + using Gemm = typename NodeOp::Base::Gemm; + using Base = HostVisitorBase; + using ElementCompute = typename Base::ElementCompute; + using Arguments = typename Base::Arguments; + + constexpr static int Rm1 = sizeof...(ChildOps); + + HostTreeVisitor(){} + template + HostTreeVisitor(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) + :Base(problem_size, impl, check_relative_equality){ } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + return cute::detail::tapply(this->ops, + [&] (auto& op) { + return op.visit(m, n, l, m_b, n_b, acc); + }, + [&] (auto&&... frg_inputs) { + return std::get(this->ops).visit(m, n, l, m_b, n_b, acc, frg_inputs...); + }, + cute::make_seq{} + ); + } +}; + + +/// General Graph visitor +template +struct HostTopoVisitor: public HostVisitorBase { +public: + using Base = HostVisitorBase; + using ElementCompute = typename Base::ElementCompute; + constexpr static int Rm1 = Base::Rm1; + using Arguments = typename Base::Arguments; + +private: + ElementCompute frg_outputs[Rm1]; +public: + HostTopoVisitor(){} + template + HostTopoVisitor(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) + :Base(problem_size, impl, check_relative_equality) { } + + template + ElementCompute visit_( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + frg_outputs[I] = cute::transform_apply(cute::get(EdgeTuple{}), + [&] (auto&& _E) { + constexpr int e = cute::remove_cvref_t::value; + return frg_outputs[e]; + }, + [&] (auto const&... frg_inputs) { + ElementCompute res = std::get(this->ops).visit(m, n, l, m_b, n_b, acc, frg_inputs...); + return res; + } + ); + + if constexpr (I < Rm1 - 1) { + return visit_(m, n, l, m_b, n_b, acc); + } else { + return frg_outputs[I]; + } + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + return visit_(m, n, l, m_b, n_b, acc); + } + +}; + + +/// SplitTree visitor +template +struct HostSplitTreeVisitor: public HostVisitorBase { +public: + using Base = HostVisitorBase; + using ElementCompute = typename Base::ElementCompute; + using Arguments = typename Base::Arguments; + + constexpr static int Rm2 = sizeof...(AuxOutTrees); + +private: + ElementCompute frg_input; +public: + HostSplitTreeVisitor(){} + template + HostSplitTreeVisitor(ProblemShapeType problem_size, TestBedImpl impl, bool check_relative_equality=false) + :Base(problem_size, impl, check_relative_equality) { } + + template + void visitAux( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator frag) { + std::get(this->ops).visit(m, n, l, m_b, n_b, frag); + + if constexpr (I < Rm2 - 1) { + return visitAux(m, n, l, m_b, n_b, frag); + } else { + return; + } + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + /// Compute the input tree + frg_input = std::get<0>(this->ops).visit(m, n, l, m_b, n_b, acc); + + /// Compute the aux out tree + visitAux(m, n, l, m_b, n_b, frg_input); + /// Visit the output tree + return std::get(this->ops).visit(m, n, l, m_b, n_b, frg_input); + } +}; + +/// Universal testbed for EVT +template +class Testbed3xEVT { +public: + // The EVT Module to test + using EVTModule = typename EVT::EVTModule; + + using TestBedImpl = typename detail::TestbedImpl; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + using ElementAccumulator = typename Kernel::ElementAccumulator; + using ElementC = typename Kernel::ElementC; + using ElementD = typename Kernel::ElementD; + + using ProblemShapeType = typename Kernel::ProblemShape; + + using LayoutTagA = typename TestBedImpl::LayoutTagA; + using LayoutTagB = typename TestBedImpl::LayoutTagB; + using LayoutTagC = typename TestBedImpl::LayoutTagC; + using LayoutTagD = typename TestBedImpl::LayoutTagD; + + // + // Methods + // + Testbed3xEVT( + bool check_relative_equality_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(init_A_, init_B_, init_C_, seed_), check_relative_equality(check_relative_equality_) { } + + Testbed3xEVT( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(init_A_, init_B_, init_C_, seed_), check_relative_equality(false) { } + + Testbed3xEVT( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(stride_factor_A_, + stride_factor_B_, + stride_factor_C_, + stride_factor_D_, + init_A_, + init_B_, + init_C_, + seed_), + check_relative_equality(false) { } + + /// Initializes data structures + void initialize(ProblemShapeType problem_size) { + // + // Allocate the GEMM workspace for A/B tensor + // + impl_.initialize(problem_size); + } + // Detail Implementation + TestBedImpl impl_; + + // Whether to use relative equality checks + bool check_relative_equality; + + bool verify(ProblemShapeType problem_size, EVTModule& host_reference) { + + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + auto coord_0 = cutlass::make_Coord(0); + + auto A = cute::make_tensor(impl_.tensor_A.host_data(), + cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a)); + auto B = cute::make_tensor(impl_.tensor_B.host_data(), + cute::make_layout(cute::make_shape(N, K, L), impl_.stride_b)); + auto LayoutD = cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + /// Reference Kernel + static int constexpr kBlockM = 64; + static int constexpr kBlockN = 64; + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { + for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { + ElementAccumulator acc[kBlockM][kBlockN]; + gett_mainloop(mainloop_params, m, n, l, acc); + /// Epilogue EVT + for (int n_b = 0; n_b < kBlockN; ++n_b) { + for (int m_b = 0; m_b < kBlockM; ++m_b) { + if (m + m_b < cute::size<0>(LayoutD) && n + n_b < cute::size<1>(LayoutD)) { + host_reference.visit(m, n, l, m_b, n_b, acc[m_b][n_b]); + } + } + } + } + } + } + + std::stringstream error_ss; + bool passed = host_reference.compare_reference(error_ss); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K + << ", Batch count = " << L << "\n\n"; + + file + << "A =\n" << impl_.tensor_A.host_view() + << "\nB =\n" << impl_.tensor_B.host_view() + << "\nC =\n" << impl_.tensor_C.host_view() << "\n\n"; + + file << error_ss.str(); + } + + return passed; + } + + bool run( + ProblemShapeType problem_size, + bool profiling = false, + int iterations = 20, + int splits = 1) { + // Fail test if insufficient CUDA device + if (!impl_.sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } + // + // Initialize the Gemm operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = impl_.sm_count; + } + else { + impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = impl_.sm_count; + } + + typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; + if constexpr (std::is_same_v) { + scheduler_args = { splits }; + } + + /// Initializes data structures + /// A/B/C/D Tensor + initialize(problem_size); + + /// Initialize the epilogue arguments + EVTModule host_reference(problem_size, impl_, check_relative_equality); + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { + impl_.tensor_A.device_data(), impl_.stride_a, + impl_.tensor_B.device_data(), impl_.stride_b + }, + { // Epilogue arguments + {}, // thread + static_cast(host_reference.get_tensor_C_ptr()), + impl_.stride_c, + static_cast(host_reference.get_tensor_D_ptr()), + impl_.stride_d + }, // Epilogue arguments end + hw_info, + scheduler_args + }; + + // Filling in the thread arguments + typename EVTModule::Arguments epilogue_args = host_reference.get_arguments(); + std::memcpy(&arguments.epilogue.thread, &epilogue_args.arg, sizeof(epilogue_args.arg)); + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + if (profiling) { + return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); + } + else { + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, host_reference); + if (!passed) { + std::cout << "Error : Failed \n"; + } + + return passed; + } +}; + + +template +bool TestAllEVT(bool check_relative_equality=false) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; + std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + + if constexpr (std::is_same_v) { + problem_size_m.push_back(768); + problem_size_n.push_back(768); + } + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + + Testbed3xEVT testbed(check_relative_equality); + bool passed = true; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + passed = testbed.run(problem_size); + + if (!passed) { + return false; + } + } + } + } + + // if we do support batched GEMM, just run one test on it to save on test time + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; + passed = testbed.run( + problem_size + ); + + if (!passed) { + return false; + } + } + + return passed; +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp b/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp index 3e5424a107..145f874725 100644 --- a/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp @@ -247,21 +247,33 @@ struct Testbed3xTensorBroadcast { auto dummy_C = cute::make_tensor(static_cast(nullptr), cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c)); ElementCompute dummy_beta(0); + auto dummy_Aux = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d)); + auto dummy_Valpha = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, 1))); + auto dummy_Vbeta = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, 1))); + cutlass::reference::host::GettEpilogueParams< + ElementScalar, ElementScalar, ElementAccumulator, ElementCompute, decltype(dummy_C), decltype(RefComputeOut), decltype(Bias), - decltype(dummy_C), + decltype(dummy_Aux), + decltype(dummy_Valpha), + decltype(dummy_Vbeta), ActivationFunctor> epilogue_params{ alpha, dummy_beta, dummy_C, RefComputeOut, Bias, - dummy_C + dummy_Aux, + dummy_Valpha, + dummy_Vbeta }; cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); @@ -347,7 +359,8 @@ struct Testbed3xTensorBroadcast { cutlass::gemm::GemmUniversalMode::kGemm, problem_size, { impl_.tensor_A.device_data(), impl_.stride_a, - impl_.tensor_B.device_data(), impl_.stride_b + impl_.tensor_B.device_data(), impl_.stride_b, + impl_.mma_promotion_interval }, { // Epilogue arguments { alpha, beta }, // ThreadOp arguments diff --git a/test/unit/gemm/device/gemv.cu b/test/unit/gemm/device/gemv.cu index 8cac590a31..253b72d7b7 100644 --- a/test/unit/gemm/device/gemv.cu +++ b/test/unit/gemm/device/gemv.cu @@ -146,7 +146,6 @@ public: view.data(), view.capacity()); } else { - // TODO: Implement the rest EXPECT_TRUE(false) << "Not implemented"; return false; } diff --git a/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu index 9a115493c4..99b446af64 100644 --- a/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu @@ -49,7 +49,6 @@ #include "testbed_symm_universal.h" #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM90_Device_Hemm_cf64h_cf64n_ls_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) { diff --git a/test/unit/gemm/device/her2k_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/her2k_cf64_cf64_tensor_op_f64_sm90.cu index fbc4efdb93..ec0e03fa84 100644 --- a/test/unit/gemm/device/her2k_cf64_cf64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/her2k_cf64_cf64_tensor_op_f64_sm90.cu @@ -47,7 +47,6 @@ #include "testbed_rank2k_universal.h" #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM90_Device_Her2k_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { diff --git a/test/unit/gemm/device/herk_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/herk_cf64_cf64_tensor_op_f64_sm90.cu index 114a20cf9c..c853ed4cf8 100644 --- a/test/unit/gemm/device/herk_cf64_cf64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/herk_cf64_cf64_tensor_op_f64_sm90.cu @@ -47,7 +47,6 @@ #include "testbed_rank_k_universal.h" #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - ///////////////////////////////////////////////////////////////////////////////////////////////// // HERK operator on CUBLAS_OP_C (row-major + conj) input layouts TEST(SM90_Device_Herk_cf64h_cf64n_l_tensor_op_f64, 64x64x16_32x32x16) { diff --git a/test/unit/gemm/device/multistage_testbed.h b/test/unit/gemm/device/multistage_testbed.h index 681e051e31..f0f71bb8f6 100644 --- a/test/unit/gemm/device/multistage_testbed.h +++ b/test/unit/gemm/device/multistage_testbed.h @@ -100,7 +100,6 @@ struct MultistageTestbed { cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); } else { - // TODO: Implement the rest EXPECT_TRUE(false) << "Not implemented"; return false; } diff --git a/test/unit/gemm/device/multistage_testbed_interleaved.h b/test/unit/gemm/device/multistage_testbed_interleaved.h index 5f332069d1..2556c0174b 100644 --- a/test/unit/gemm/device/multistage_testbed_interleaved.h +++ b/test/unit/gemm/device/multistage_testbed_interleaved.h @@ -105,7 +105,6 @@ struct MultistageInterleavedTestbed { view.data(), view.capacity()); } else { - // TODO: Implement the rest EXPECT_TRUE(false) << "Not implemented"; return false; } diff --git a/test/unit/gemm/device/sm90_evt_operations.hpp b/test/unit/gemm/device/sm90_evt_operations.hpp new file mode 100644 index 0000000000..767e84f09a --- /dev/null +++ b/test/unit/gemm/device/sm90_evt_operations.hpp @@ -0,0 +1,510 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Host reference and operations for Sm90 EVT unit test +*/ +#pragma once +#include "gemm_testbed_3x_evt.hpp" + +////////////////////////////////////////////////////////////////////////////// +/// Host references used for testing +namespace test::gemm::device { +template +using HEVT = HostTreeVisitor; + +template +using HDAG = HostTopoVisitor; + +template +using HST = HostSplitTreeVisitor; + +/// D = alpha * acc + beta * C + AuxLoad +template +class HostEVTAuxLoad { +public: + using ScalarAlpha = HostScalarBroadcast; + using AccFetchNode = HostAccumulator; + using AuxLoadNode = HostAuxLoad; + using TernaryCompute0 = HEVT, ScalarAlpha, AccFetchNode, AuxLoadNode>; + using ScalarBeta = HostScalarBroadcast; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, TernaryCompute0>; + using EVTModule = HEVT, TernaryCompute1>; +}; + +/// D = alpha * acc + beta * C + per-column bias +template +class HostPerColBias { +public: + using ScalarAlpha = HostScalarBroadcast; + using AccFetchNode = HostAccumulator; + using RowBroadcastNode = HostRowBroadcast; + using TernaryCompute0 = HEVT, ScalarAlpha, AccFetchNode, RowBroadcastNode>; + using ScalarBeta = HostScalarBroadcast; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, TernaryCompute0>; + using EVTModule = HEVT, TernaryCompute1>; +}; + +/// D = beta * C + Graph(relu(alpha * acc + aux) + aux) +/// Testing EVT - DAG structure +template +class HostEVTDAG { +public: + using ScalarAlpha = HostScalarBroadcast; + using AccFetchNode = HostAccumulator; + using AuxLoadNode = HostAuxLoad; + using DAGNode = HDAG< + Gemm, + cute::tuple< + cute::tuple<>, // 0. alpha + cute::tuple<>, // 1. acc + cute::tuple<>, // 2. aux load + cute::tuple, // 3. alpha * acc + aux load + cute::tuple, // relu(alpha * acc + aux load) + cute::tuple // relu(alpha * acc + aux load) + aux load + >, + ScalarAlpha, + AccFetchNode, + AuxLoadNode, + HostCompute, + HostCompute, + HostCompute + >; + using ScalarBeta = HostScalarBroadcast; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, DAGNode>; + using EVTModule = HEVT, TernaryCompute1>; +}; + +/// EVT = alpha * acc + C +/// D = Graph(maximum(EVT + per-row bias, EVT)) +/// Testing DAG - EVT +template +class HostDAGEVT { +public: + using EVTNode = HEVT< + HostAuxStore, + HEVT< + HostCompute, + HostScalarBroadcast, + HostAccumulator, + HostAuxLoad + > + >; + using EVTModule = HEVT< + HostAuxStore, + HDAG< + Gemm, + cute::tuple< + cute::tuple<>, // 0. EVT + cute::tuple<>, // 1. per-row bias + cute::tuple, // 2. EVT + per-row bias + cute::tuple // 3. maximum(EVT + per-row bias, EVT) + >, + EVTNode, + HostColBroadcast, + HostCompute, + HostCompute + > + >; +}; + +/// Xreduce(alpha * acc + beta * C) +template class, class> class ReduceOp> +class HostReduce { +public: + using ScalarAlpha = HostScalarBroadcast; + using AccFetchNode = HostAccumulator; + using BinaryCompute0 = HEVT, ScalarAlpha, AccFetchNode>; + using ScalarBeta = HostScalarBroadcast; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, BinaryCompute0>; + using ReduceNode = HEVT, TernaryCompute1>; + using EVTModule = HEVT, ReduceNode>; +}; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template class ActivationFn, class ElementD> +class HostScaledLinCombPerRowBiasEltAct { +public: + using EVTModule = HEVT< + HostAuxStore, + HEVT< + HostCompute::Op>, // activation(Z) * scaled_d + HEVT< + HostCompute, // activation(Z) + HEVT< + HostCompute, + HostScalarBroadcast, // scale_c * beta + HostAuxLoad, // C + HEVT< + HostCompute, + HostScalarBroadcast, // scale_a * scale_b * alpha + HostAccumulator, + HostColBroadcast, + > + > + >, + HostScalarBroadcast, // scale_d + > + >; +}; + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z +template class ActivationFn, class ElementD> +class HostScaledLinCombPerRowBiasEltActAmaxAux { +public: + template + using amax = cutlass::maximum_absolute_value_reduction; + using EVTModule = HEVT< + HostAuxStore, + HST, + HostScalarBroadcast, // scale_c * beta + HostAuxLoad, // C + HEVT< + HostCompute, + HostScalarBroadcast, // scale_a * scale_b * alpha + HostAccumulator, + HostColBroadcast, + > + >, + // D = activation(Z) * scaled_d, amax_d = max(abs(elements in D)) + HEVT< + HostCompute::Op>, + HEVT< + HostScalarReduce, + HEVT< + HostCompute, //activation(Z) * scaled_d + HostAccumulator, // Z + > + >, + HostScalarBroadcast, // scale_d + >, + // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) + HEVT< + HostAuxStore, + HEVT< + HostCompute::Op>, + HEVT< + HostScalarReduce, + HostAccumulator + >, + HostScalarBroadcast + > + > + > + >; +}; +} // namespace test::gemm::device + +////////////////////////////////////////////////////////////////////////////// +namespace cutlass::epilogue { +namespace fusion { + +////////////////////////////////////////////////////////////////////////////// +/// D = alpha * acc + beta * C + AuxLoad +template< + class EpilogueDescriptor, + class AuxLoadDescriptor, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombAuxLoad = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90AuxLoad< + AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, + typename AuxLoadDescriptor::Element, + typename AuxLoadDescriptor::Stride, typename AuxLoadDescriptor::SmemLayoutAtom, + typename AuxLoadDescriptor::CopyOpS2R // aux load + > + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// Example DAG +/// beta * C + Graph(alpha * acc + gamma + acc) +template< + typename EpilogueDescriptor, + typename AuxLoadDescriptor, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombEVTDAG = + Sm90EVT, // beta * C + (alpha * acc + aux) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90TopologicalVisitor< + ElementCompute, + cute::tuple< + cute::seq<>, // 0. alpha + cute::seq<>, // 1. acc + cute::seq<>, // 2. aux load + cute::seq<1, 0, 2>, // 3. alpha * acc + aux load + cute::seq<3>, // relu(alpha & acc + aux load) + cute::seq<2, 4> // relu(alpha * acc + aux load) + aux load + >, + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90AuxLoad< + AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, + typename AuxLoadDescriptor::Element, typename AuxLoadDescriptor::Stride, + typename AuxLoadDescriptor::SmemLayoutAtom, typename AuxLoadDescriptor::CopyOpS2R>, + Sm90Compute, + Sm90Compute, + Sm90Compute + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// Example DAG +/// EVT = alpha * acc + C +/// D = Graph(maximum(EVT + per-row bias, EVT)) +template< + class EpilogueDescriptor, + class AuxStoreDescriptor, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombDAGEVT = + Sm90TopologicalVisitor< + ElementCompute, + cute::tuple< + cute::seq<>, + cute::seq<>, + cute::seq<1, 0>, + cute::seq<0, 2> + >, + Sm90EVT< + Sm90AuxStore< + AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, + typename AuxStoreDescriptor::Element, RoundStyle, typename AuxStoreDescriptor::Stride, + typename AuxStoreDescriptor::SmemLayoutAtom, typename AuxStoreDescriptor::CopyOpR2S>, + Sm90EVT, + Sm90ScalarBroadcast, + Sm90AccFetch, + Sm90SrcFetch + > + >, + Sm90ColBroadcast<0, typename EpilogueDescriptor::TileShape, ElementBias>, + Sm90Compute, + Sm90Compute + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = alpha * acc + beta * C + per-column bias +template< + class EpilogueDescriptor, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColumnBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90RowBroadcast< + ceil_div( + EpilogueDescriptor::StagesC, + size(shape_div(take<0, 2>(typename EpilogueDescriptor::TileShape{}), typename EpilogueDescriptor::EpilogueTile{})) + ) + 1, + typename EpilogueDescriptor::TileShape, + ElementBias + > + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = per-column reduce(alpha * acc + beta * C) +template< + template class RegReduceFn, + template class GmemReduceFn, + class ElementReduce, + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColumnReduce = + Sm90EVT, // per column reduce + Sm90EVT, // beta * C + alpha * acc + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast, // alpha + Sm90AccFetch // acc + > + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = per-row reduce(alpha * acc + beta * C) +template< + template class RegReduceFn, + template class GmemReduceFn, + class ElementReduce, + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerRowReduce = + Sm90EVT, // per column reduce + Sm90EVT, // beta * C + alpha * acc + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast, // alpha + Sm90AccFetch // acc + > + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = scalar reduce(alpha * acc + beta * C) +template< + template class RegReduceFn, + template class GmemReduceFn, + class ElementReduce, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombScalarReduce = + Sm90EVT, // per column reduce + Sm90EVT, // beta * C + alpha * acc + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast, // alpha + Sm90AccFetch // acc + > + > + >; +} // namespace fusion + +namespace collective { + +template< + typename TileShape_MNK, + typename EpilogueTileType, + typename ElementC, + typename ElementD, + typename Schedule +> +struct EpilogueDescriptor{ + using TileShape = TileShape_MNK; + using EpilogueTile = + decltype(detail::sm90_compute_tile_shape_or_override()); + using DispatchPolicy = + decltype(detail::sm90_get_tma_dispatch_policy()); + constexpr static int StagesC = DispatchPolicy::StagesC; + constexpr static int StagesD = DispatchPolicy::StagesD; +}; + + +template< + typename EpilogueDescriptor, + typename GmemLayoutTagAux, + typename ElementAux +> +struct AuxLoadDescriptor{ + constexpr static int Stages = EpilogueDescriptor::StagesC; + using Element = ElementAux; + using Stride = gemm::TagToStrideC_t; + using SmemLayoutAtom = + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()); + using CopyOpS2R = + decltype(detail::sm90_get_smem_load_op_for_source()); +}; + + +template< + typename EpilogueDescriptor, + typename GmemLayoutTagAux, + typename ElementAux +> +struct AuxStoreDescriptor{ + constexpr static int Stages = EpilogueDescriptor::StagesD; + using Element = ElementAux; + using Stride = gemm::TagToStrideC_t; + using SmemLayoutAtom = + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()); + using CopyOpR2S = + decltype(detail::sm90_get_smem_store_op_for_accumulator()); +}; + +} // namespace collective + +} // namespace cutlass::epilogue diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu index 942d1862de..1b3640b7da 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized.cu @@ -91,7 +91,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } /////////////////////////////////////////////////////////////////////////////// @@ -128,7 +128,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } /////////////////////////////////////////////////////////////////////////////// @@ -165,7 +165,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } /////////////////////////////////////////////////////////////////////////////// @@ -202,7 +202,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } /////////////////////////////////////////////////////////////////////////////// @@ -242,7 +242,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } /////////////////////////////////////////////////////////////////////////////// @@ -279,7 +279,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } /////////////////////////////////////////////////////////////////////////////// @@ -316,7 +316,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } /////////////////////////////////////////////////////////////////////////////// @@ -353,7 +353,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } @@ -394,7 +394,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } /////////////////////////////////////////////////////////////////////////////// @@ -431,7 +431,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } /////////////////////////////////////////////////////////////////////////////// @@ -468,7 +468,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } /////////////////////////////////////////////////////////////////////////////// @@ -505,7 +505,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } @@ -546,7 +546,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } /////////////////////////////////////////////////////////////////////////////// @@ -583,7 +583,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } /////////////////////////////////////////////////////////////////////////////// @@ -620,7 +620,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } /////////////////////////////////////////////////////////////////////////////// @@ -657,7 +657,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_warpspecialized, 64x128x >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu index 9c93188b66..dcb4a0d9e9 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu @@ -807,7 +807,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_cooperative_epilogue, 25 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1) { @@ -844,7 +844,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_cooperative_epilogue, 25 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 128x128x64_2x2x1) { @@ -881,7 +881,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 12 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 128x128x64_2x2x1) { @@ -918,7 +918,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 12 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_load.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_load.cu new file mode 100644 index 0000000000..aa8ab75947 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_aux_load.cu @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for Sm90 f16_f16_f16 with cooperative EVT epilogue + D = alpha * acc + beta * c + aux_load +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_evt.hpp" +#include "sm90_evt_operations.hpp" + + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_AuxLoadF16_RowMajor) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule + >; + using AuxLoadDescriptor = cutlass::epilogue::collective::AuxLoadDescriptor< + EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t + >; + + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombAuxLoad< + EpilogueDescriptor, AuxLoadDescriptor, cutlass::half_t, float, float>; + + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + using HostReference = test::gemm::device::HostEVTAuxLoad< + Gemm, cutlass::half_t, cutlass::layout::RowMajor + >; + bool passed = test::gemm::device::TestAllEVT(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_AuxLoadF16_ColumnMajor) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule + >; + using AuxLoadDescriptor = cutlass::epilogue::collective::AuxLoadDescriptor< + EpilogueDescriptor, cutlass::layout::ColumnMajor, cutlass::half_t + >; + + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombAuxLoad< + EpilogueDescriptor, AuxLoadDescriptor, cutlass::half_t, float, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + using HostReference = test::gemm::device::HostEVTAuxLoad< + Gemm, cutlass::half_t, cutlass::layout::ColumnMajor + >; + bool passed = test::gemm::device::TestAllEVT(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 128x128x64_2x2x1_AuxLoadF32_ColumnMajor) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule + >; + using AuxLoadDescriptor = cutlass::epilogue::collective::AuxLoadDescriptor< + EpilogueDescriptor, cutlass::layout::ColumnMajor, float + >; + + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombAuxLoad< + EpilogueDescriptor, AuxLoadDescriptor, cutlass::half_t, float, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + using HostReference = test::gemm::device::HostEVTAuxLoad< + Gemm, float, cutlass::layout::ColumnMajor + >; + bool passed = test::gemm::device::TestAllEVT(); + EXPECT_TRUE(passed); +} +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu index d95b14a252..a24d8d2b34 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu @@ -51,7 +51,6 @@ #include "../../common/cutlass_unit_test.h" -#include "testing_elementwise.hpp" #include "gemm_testbed_3x.hpp" @@ -66,8 +65,9 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 using TileShape_MNK = Shape<_256,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeElementwise< - cutlass::epilogue::thread::ReLu>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::LinCombEltAct< + cutlass::epilogue::thread::ReLu, cutlass::half_t, float>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -76,7 +76,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 float, float, cutlass::half_t, LayoutC, 8, cutlass::half_t, LayoutC, 8, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -97,11 +98,18 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - bool passed = test::gemm::device::TestAll(); + test::gemm::device::Testbed3x testbed; + bool passed = test::gemm::device::TestAll(1, 1, testbed); EXPECT_TRUE(passed); } -TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_ReLU) { +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_ReLU_Legacy) { +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" // Suppress deprecation warnings +#ifdef _MSC_VER +#pragma warning( push ) +#pragma warning( disable : 4996 ) +#endif // _MSC_VER using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::RowMajor; @@ -140,20 +148,24 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - bool passed = test::gemm::device::TestAllBiasElementwise(); + bool passed = test::gemm::device::TestAllBiasElementwise(1, 1); EXPECT_TRUE(passed); +#ifdef _MSC_VER +#pragma warning( pop ) +#endif // _MSC_VER +#pragma GCC diagnostic pop // Re-enable deprecation warnings } -TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_GELU) { +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_ReLU) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::RowMajor; using TileShape_MNK = Shape<_256,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; - static constexpr bool StoreT = true; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise< - cutlass::epilogue::thread::GELU, cutlass::half_t, cutlass::plus, StoreT, float>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< + LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, float>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -162,7 +174,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 float, float, cutlass::half_t, LayoutC, 8, cutlass::half_t, LayoutC, 8, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -183,21 +196,20 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - bool check_relative_equality = true; - bool passed = test::gemm::device::TestAllBiasElementwise(check_relative_equality); + bool passed = test::gemm::device::TestAllBiasElementwise(1, 1); EXPECT_TRUE(passed); } -TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_ReLU_NoStoreT) { +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_GELU) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::RowMajor; using TileShape_MNK = Shape<_256,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; - static constexpr bool StoreT = false; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise< - cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::plus, StoreT, float>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< + LayoutC, cutlass::epilogue::thread::GELU, cutlass::half_t, float, cutlass::half_t, float>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -206,7 +218,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 float, float, cutlass::half_t, LayoutC, 8, cutlass::half_t, LayoutC, 8, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -227,21 +240,21 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - bool passed = test::gemm::device::TestAllBiasElementwise(); + bool check_relative_equality = true; + bool passed = test::gemm::device::TestAllBiasElementwise(1, 1, check_relative_equality); EXPECT_TRUE(passed); } -TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_Negate) { - +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_ReLU_NoStoreT) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::RowMajor; using TileShape_MNK = Shape<_256,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; - static constexpr bool StoreT = true; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise< - test::gemm::device::detail::Negate, cutlass::half_t, cutlass::plus, StoreT, float>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltAct< + cutlass::epilogue::thread::ReLu, cutlass::half_t, float, float>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -250,7 +263,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 float, float, cutlass::half_t, LayoutC, 8, cutlass::half_t, LayoutC, 8, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -271,21 +285,21 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - bool passed = test::gemm::device::TestAllBiasElementwise(); + bool passed = test::gemm::device::TestAllBiasElementwise(1, 1); EXPECT_TRUE(passed); } -TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32Mul_ReLU) { +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_Negate) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; using TileShape_MNK = Shape<_256,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; - static constexpr bool StoreT = true; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise< - cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, float>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< + LayoutC, cutlass::negate, cutlass::half_t, float, cutlass::half_t, float>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -294,7 +308,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25 float, float, cutlass::half_t, LayoutC, 8, cutlass::half_t, LayoutC, 8, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -315,21 +330,21 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 25 using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - bool passed = test::gemm::device::TestAllBiasElementwise(); + bool passed = test::gemm::device::TestAllBiasElementwise(1, 1); EXPECT_TRUE(passed); } -TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32Mul_ReLU) { +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_ReLU) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; using TileShape_MNK = Shape<_256,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; - static constexpr bool StoreT = true; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise< - cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, float>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< + LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, float>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -338,7 +353,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 float, float, cutlass::half_t, LayoutC, 8, cutlass::half_t, LayoutC, 8, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -359,11 +375,11 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - bool passed = test::gemm::device::TestAllBiasElementwise(); + bool passed = test::gemm::device::TestAllBiasElementwise(1, 1); EXPECT_TRUE(passed); } -TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32Mul_ReLU_VoidC) { +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF32_ReLU_VoidC) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; @@ -371,9 +387,9 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 using TileShape_MNK = Shape<_256,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; - static constexpr bool StoreT = true; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise< - cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, float>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< + LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, float>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -382,7 +398,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 float, float, void, LayoutC, 8, cutlass::half_t, LayoutC, 8, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -407,7 +424,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 EXPECT_TRUE(passed); } -TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF16Mul_ReLU_VoidC) { +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasF16_ReLU_VoidC) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; @@ -415,9 +432,9 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 using TileShape_MNK = Shape<_256,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; - static constexpr bool StoreT = true; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise< - cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, cutlass::half_t>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< + LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, cutlass::half_t>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -426,7 +443,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 float, float, void, LayoutC, 8, cutlass::half_t, LayoutC, 8, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -451,7 +469,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 EXPECT_TRUE(passed); } -TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasS8Mul_ReLU_VoidC) { +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasS8_ReLU_VoidC) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; @@ -459,9 +477,9 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 using TileShape_MNK = Shape<_256,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; - static constexpr bool StoreT = true; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperativeBiasElementwise< - cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, int8_t>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< + LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, int8_t>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -470,7 +488,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 float, float, void, LayoutC, 8, cutlass::half_t, LayoutC, 8, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -495,4 +514,4 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 EXPECT_TRUE(passed); } -#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) \ No newline at end of file +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_dag.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_dag.cu new file mode 100644 index 0000000000..9fd58b1d17 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_dag.cu @@ -0,0 +1,170 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for Sm90 f16_f16_f16 cooperative DAG epilogue + EVTDAG: D = beta * C + Graph(relu(alpha * acc + aux) + aux) + DAGEVT: EVT = alpha * acc + C, D = Graph(maximum(EVT + per-row bias, EVT)) +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_evt.hpp" +#include "sm90_evt_operations.hpp" + + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_EVTDAG) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule>; + + using AuxLoadDescriptor = cutlass::epilogue::collective::AuxLoadDescriptor< + EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t>; + + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombEVTDAG< + EpilogueDescriptor, AuxLoadDescriptor, cutlass::half_t, float, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + bool passed = test::gemm::device::TestAllEVT>(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 128x128x64_2x2x1_DAGEVT) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule>; + + using AuxStoreDescriptor = cutlass::epilogue::collective::AuxStoreDescriptor< + EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t>; + + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombDAGEVT< + EpilogueDescriptor, AuxStoreDescriptor, cutlass::half_t, float, cutlass::half_t, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + bool passed = test::gemm::device::TestAllEVT>(); + EXPECT_TRUE(passed); +} +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_reduce.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_reduce.cu new file mode 100644 index 0000000000..998a2b7bc9 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_reduce.cu @@ -0,0 +1,201 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for Sm90 f16_f16_f16 cooperative EVT epilogue + D = row|column|scalar_reduce(alpha * acc + beta * C) +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_evt.hpp" +#include "sm90_evt_operations.hpp" + + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_RowReduce) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombPerColumnReduce< + cutlass::plus, cutlass::red, float, TileShape_MNK, cutlass::half_t, float, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + using HostReference = test::gemm::device::HostReduce; + bool passed = test::gemm::device::TestAllEVT(true); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_ColumnReduce) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombPerRowReduce< + cutlass::plus, cutlass::red, float, TileShape_MNK, cutlass::half_t, float, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + using HostReference = test::gemm::device::HostReduce; + bool passed = test::gemm::device::TestAllEVT(true); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_ScalarReduce) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombScalarReduce< + cutlass::plus, cutlass::red, float, cutlass::half_t, float, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + using HostReference = test::gemm::device::HostReduce; + bool passed = test::gemm::device::TestAllEVT(true); + EXPECT_TRUE(passed); +} +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_row_broadcast.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_row_broadcast.cu new file mode 100644 index 0000000000..b56a7db746 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_row_broadcast.cu @@ -0,0 +1,163 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for Sm90 f16_f16_f16 cooperative EVT epilogue + D = alpha * acc + beta * C + per_column_bias +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_evt.hpp" +#include "sm90_evt_operations.hpp" + + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_RowBroadcastF16) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule>; + + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombPerColumnBias< + EpilogueDescriptor, cutlass::half_t, float, cutlass::half_t, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + bool passed = test::gemm::device::TestAllEVT>(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_RowBroadcastF32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule>; + + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombPerColumnBias< + EpilogueDescriptor, cutlass::half_t, float, float, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + bool passed = test::gemm::device::TestAllEVT>(); + EXPECT_TRUE(passed); +} +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong.cu index 472ec9a7b9..171d4abec6 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong.cu @@ -1144,7 +1144,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_persistent_epilogue, 128 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1) { @@ -1187,7 +1187,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_persistent_epilogue, 128 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1) { @@ -1230,7 +1230,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1) { @@ -1273,7 +1273,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); + EXPECT_TRUE(test::gemm::device::TestAll(1, 1)); } #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_aux_load.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_aux_load.cu new file mode 100644 index 0000000000..c2e2a43ef0 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_aux_load.cu @@ -0,0 +1,229 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for Sm90 f16_f16_f16 with persistent EVT epilogue + D = alpha * acc + beta * c + aux_load +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_evt.hpp" +#include "sm90_evt_operations.hpp" + + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_AuxLoadF16_RowMajor) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule + >; + using AuxLoadDescriptor = cutlass::epilogue::collective::AuxLoadDescriptor< + EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t + >; + + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombAuxLoad< + EpilogueDescriptor, AuxLoadDescriptor, cutlass::half_t, float, float>; + + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + using HostReference = test::gemm::device::HostEVTAuxLoad; + bool passed = test::gemm::device::TestAllEVT(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_AuxLoadF16_ColumnMajor) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule + >; + using AuxLoadDescriptor = cutlass::epilogue::collective::AuxLoadDescriptor< + EpilogueDescriptor, cutlass::layout::ColumnMajor, cutlass::half_t + >; + + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombAuxLoad< + EpilogueDescriptor, AuxLoadDescriptor, cutlass::half_t, float, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + using HostReference = test::gemm::device::HostEVTAuxLoad; + bool passed = test::gemm::device::TestAllEVT(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_AuxLoadF32_ColumnMajor) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule + >; + using AuxLoadDescriptor = cutlass::epilogue::collective::AuxLoadDescriptor< + EpilogueDescriptor, cutlass::layout::ColumnMajor, float + >; + + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombAuxLoad< + EpilogueDescriptor, AuxLoadDescriptor, cutlass::half_t, float, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + using HostReference = test::gemm::device::HostEVTAuxLoad; + bool passed = test::gemm::device::TestAllEVT(); + EXPECT_TRUE(passed); +} +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu index 16a063af1c..b3af865116 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu @@ -51,7 +51,6 @@ #include "../../common/cutlass_unit_test.h" -#include "testing_elementwise.hpp" #include "gemm_testbed_3x.hpp" @@ -66,8 +65,9 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 using TileShape_MNK = Shape<_128,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedElementwise< - cutlass::epilogue::thread::ReLu>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::LinCombEltAct< + cutlass::epilogue::thread::ReLu, cutlass::half_t, float>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -76,7 +76,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 float, float, cutlass::half_t, LayoutC, 8, cutlass::half_t, LayoutC, 8, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -96,7 +97,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - bool passed = test::gemm::device::TestAll(); + test::gemm::device::Testbed3x testbed; + bool passed = test::gemm::device::TestAll(1, 1, testbed); EXPECT_TRUE(passed); } @@ -107,9 +109,9 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 using TileShape_MNK = Shape<_128,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; - static constexpr bool StoreT = true; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise< - cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::plus, StoreT, float>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< + LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, float>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -118,7 +120,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 float, float, cutlass::half_t, LayoutC, 8, cutlass::half_t, LayoutC, 8, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -139,7 +142,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - bool passed = test::gemm::device::TestAllBiasElementwise(); + bool passed = test::gemm::device::TestAllBiasElementwise(1, 1); EXPECT_TRUE(passed); } @@ -150,9 +153,9 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 using TileShape_MNK = Shape<_128,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; - static constexpr bool StoreT = true; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise< - cutlass::epilogue::thread::GELU, cutlass::half_t, cutlass::plus, StoreT, float>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< + LayoutC, cutlass::epilogue::thread::GELU, cutlass::half_t, float, cutlass::half_t, float>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -161,7 +164,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 float, float, cutlass::half_t, LayoutC, 8, cutlass::half_t, LayoutC, 8, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -183,7 +187,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 using Gemm = cutlass::gemm::device::GemmUniversalAdapter; bool check_relative_equality = true; - bool passed = test::gemm::device::TestAllBiasElementwise(check_relative_equality); + bool passed = test::gemm::device::TestAllBiasElementwise(1, 1, check_relative_equality); EXPECT_TRUE(passed); } @@ -194,9 +198,9 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 using TileShape_MNK = Shape<_128,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; - static constexpr bool StoreT = false; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise< - cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::plus, StoreT, float>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltAct< + cutlass::epilogue::thread::ReLu, cutlass::half_t, float, float>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -205,7 +209,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 float, float, cutlass::half_t, LayoutC, 8, cutlass::half_t, LayoutC, 8, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -226,7 +231,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - bool passed = test::gemm::device::TestAllBiasElementwise(); + bool passed = test::gemm::device::TestAllBiasElementwise(1, 1); EXPECT_TRUE(passed); } @@ -238,9 +243,9 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 using TileShape_MNK = Shape<_128,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; - static constexpr bool StoreT = true; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise< - test::gemm::device::detail::Negate, cutlass::half_t, cutlass::plus, StoreT, float>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< + LayoutC, cutlass::negate, cutlass::half_t, float, cutlass::half_t, float>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -249,7 +254,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 float, float, cutlass::half_t, LayoutC, 8, cutlass::half_t, LayoutC, 8, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -270,11 +276,11 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - bool passed = test::gemm::device::TestAllBiasElementwise(); + bool passed = test::gemm::device::TestAllBiasElementwise(1, 1); EXPECT_TRUE(passed); } -TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32Mul_ReLU) { +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32_ReLU) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; @@ -282,9 +288,9 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128 using TileShape_MNK = Shape<_128,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; - static constexpr bool StoreT = true; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise< - cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, float>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< + LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, float>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -293,7 +299,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128 float, float, cutlass::half_t, LayoutC, 8, cutlass::half_t, LayoutC, 8, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -314,55 +321,11 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_persistent_epilogue, 128 using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - bool passed = test::gemm::device::TestAllBiasElementwise(); - EXPECT_TRUE(passed); -} - -TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32Mul_ReLU) { - - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::RowMajor; - using TileShape_MNK = Shape<_128,_128,_64>; - using ClusterShape_MNK = Shape<_2,_2,_1>; - - static constexpr bool StoreT = true; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise< - cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, float>; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - TileShape_MNK, ClusterShape_MNK, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - cutlass::half_t, LayoutC, 8, - cutlass::half_t, LayoutC, 8, - EpilogueSchedule - >::CollectiveOp; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - cutlass::half_t, LayoutA, 8, - cutlass::half_t, LayoutB, 8, - float, - TileShape_MNK, ClusterShape_MNK, - cutlass::gemm::collective::StageCountAutoCarveout, - cutlass::gemm::KernelTmaWarpSpecializedPingpong - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveMainloop, - CollectiveEpilogue - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - bool passed = test::gemm::device::TestAllBiasElementwise(); + bool passed = test::gemm::device::TestAllBiasElementwise(1, 1); EXPECT_TRUE(passed); } -TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32Mul_ReLU_VoidC) { +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF32_ReLU_VoidC) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; @@ -370,9 +333,9 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 using TileShape_MNK = Shape<_128,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; - static constexpr bool StoreT = true; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise< - cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, float>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< + LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, float>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -381,7 +344,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 float, float, void, LayoutC, 8, cutlass::half_t, LayoutC, 8, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -406,7 +370,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 EXPECT_TRUE(passed); } -TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF16Mul_ReLU_VoidC) { +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasF16_ReLU_VoidC) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; @@ -414,9 +378,9 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 using TileShape_MNK = Shape<_128,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; - static constexpr bool StoreT = true; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise< - cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, cutlass::half_t>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< + LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, cutlass::half_t>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -425,7 +389,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 float, float, void, LayoutC, 8, cutlass::half_t, LayoutC, 8, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -450,7 +415,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 EXPECT_TRUE(passed); } -TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasS8Mul_ReLU_VoidC) { +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_BiasS8_ReLU_VoidC) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; @@ -458,9 +423,9 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 using TileShape_MNK = Shape<_128,_128,_64>; using ClusterShape_MNK = Shape<_2,_2,_1>; - static constexpr bool StoreT = true; - using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedBiasElementwise< - cutlass::epilogue::thread::ReLu, cutlass::half_t, cutlass::multiplies, StoreT, int8_t>; + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< + LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, int8_t>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -469,7 +434,8 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 float, float, void, LayoutC, 8, cutlass::half_t, LayoutC, 8, - EpilogueSchedule + EpilogueSchedule, + FusionOperation >::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< @@ -494,4 +460,4 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 EXPECT_TRUE(passed); } -#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) \ No newline at end of file +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_dag.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_dag.cu new file mode 100644 index 0000000000..10b1983e81 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_dag.cu @@ -0,0 +1,170 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for Sm90 f16_f16_f16 persistent DAG epilogue + EVTDAG: D = beta * C + Graph(relu(alpha * acc + aux) + aux) + DAGEVT: EVT = alpha * acc + C, D = Graph(maximum(EVT + per-row bias, EVT)) +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_evt.hpp" +#include "sm90_evt_operations.hpp" + + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_EVTDAG) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule>; + + using AuxLoadDescriptor = cutlass::epilogue::collective::AuxLoadDescriptor< + EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t>; + + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombEVTDAG< + EpilogueDescriptor, AuxLoadDescriptor, cutlass::half_t, float, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + bool passed = test::gemm::device::TestAllEVT>(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_DAGEVT) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule>; + + using AuxStoreDescriptor = cutlass::epilogue::collective::AuxStoreDescriptor< + EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::half_t>; + + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombDAGEVT< + EpilogueDescriptor, AuxStoreDescriptor, cutlass::half_t, float, cutlass::half_t, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + bool passed = test::gemm::device::TestAllEVT>(); + EXPECT_TRUE(passed); +} +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_reduce.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_reduce.cu new file mode 100644 index 0000000000..9b0c42a076 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_reduce.cu @@ -0,0 +1,201 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for Sm90 f16_f16_f16 persistent EVT epilogue + D = row|column|scalar_reduce(alpha * acc + beta * C) +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_evt.hpp" +#include "sm90_evt_operations.hpp" + + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_RowReduce) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombPerColumnReduce< + cutlass::plus, cutlass::red, float, TileShape_MNK, cutlass::half_t, float, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + using HostReference = test::gemm::device::HostReduce; + bool passed = test::gemm::device::TestAllEVT(true); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_ColumnReduce) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombPerRowReduce< + cutlass::plus, cutlass::red, float, TileShape_MNK, cutlass::half_t, float, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + using HostReference = test::gemm::device::HostReduce; + bool passed = test::gemm::device::TestAllEVT(true); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_ScalarReduce) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombScalarReduce< + cutlass::plus, cutlass::red, float, cutlass::half_t, float, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + using HostReference = test::gemm::device::HostReduce; + bool passed = test::gemm::device::TestAllEVT(true); + EXPECT_TRUE(passed); +} +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_row_broadcast.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_row_broadcast.cu new file mode 100644 index 0000000000..7a63657461 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_row_broadcast.cu @@ -0,0 +1,163 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for Sm90 f16_f16_f16 persistent EVT epilogue + D = alpha * acc + beta * C + per_column_bias +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_evt.hpp" +#include "sm90_evt_operations.hpp" + + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_RowBroadcastF16) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule>; + + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombPerColumnBias< + EpilogueDescriptor, cutlass::half_t, float, cutlass::half_t, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + bool passed = test::gemm::device::TestAllEVT>(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128x128x64_2x2x1_RowBroadcastF32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + + using EpilogueDescriptor = cutlass::epilogue::collective::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::half_t, cutlass::half_t, EpilogueSchedule>; + + using FusionCallbacks = cutlass::epilogue::fusion::Sm90LinCombPerColumnBias< + EpilogueDescriptor, cutlass::half_t, float, float, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + bool passed = test::gemm::device::TestAllEVT>(); + EXPECT_TRUE(passed); +} +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cooperative_stream_k.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cooperative_stream_k.cu new file mode 100644 index 0000000000..f307a682d8 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cooperative_stream_k.cu @@ -0,0 +1,992 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface with stream-K scheduling +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 128x128x64_1x1x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 256x128x64_1x2x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementAccumulator = float; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_1,_2,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 2x2x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 128x128x64_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 256x128x64_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 128x128x64_2x2x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 256x128x64_2x2x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_1,_2,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 4x1x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 128x128x64_4x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_4,_1,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 128x128x64_4x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_4,_1,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 128x128x64_4x1x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_4,_1,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 128x128x64_4x1x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_4,_1,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 1x4x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 128x128x64_1x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_1,_4,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 128x128x64_1x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_1,_4,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 128x128x64_1x4x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_1,_4,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 128x128x64_1x4x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_1,_4,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 2x4x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + + +TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 256x128x64_2x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_4,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 256x128x64_2x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_4,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f32n_tensor_op_gmma_f32_cooperative_stream_k, 256x128x64_2x4x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_4,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 256x128x64_2x4x1) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_4,_1>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_tensor_op_gmma_f32_cooperative_stream_k_epilogue, 256x128x64_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::TmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f16t_tensor_op_gmma_f32_cooperative_stream_k_epilogue, 256x128x64_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::TmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative_stream_k_epilogue, 128x128x64_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::TmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_stream_k_epilogue, 128x128x64_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::TmaWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_stream_k_epilogue, 256x128x64_2x2x1_BiasF32_ReLU) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< + LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::TestAllBiasElementwise(); + EXPECT_TRUE(passed); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_tensor_broadcast.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_tensor_broadcast.cu index 5aff82b944..5436e78bfd 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_tensor_broadcast.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_tensor_broadcast.cu @@ -49,7 +49,6 @@ #include "../../common/cutlass_unit_test.h" #include "gemm_testbed_3x_tensor_broadcast.hpp" -#include "testing_elementwise.hpp" #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) @@ -129,7 +128,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f32_tensor_broadcast, 64x128 cutlass::epilogue::thread::ReLu, cutlass::plus, cutlass::plus, - test::gemm::device::detail::Negate + cutlass::negate >, cutlass::gemm::EpilogueDefault>>; @@ -178,7 +177,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16t_tensor_op_gmma_f32_tensor_broadcast, 64x128 cutlass::epilogue::thread::ReLu, cutlass::multiplies, cutlass::plus, - test::gemm::device::detail::Negate + cutlass::negate >, cutlass::gemm::EpilogueDefault>>; @@ -226,7 +225,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_tensor_op_gmma_f32_tensor_broadcast, 128x12 cutlass::epilogue::thread::ReLu, cutlass::epilogue::thread::detail::NoOp, cutlass::plus, - test::gemm::device::detail::Negate + cutlass::negate >, cutlass::gemm::EpilogueDefault>>; @@ -275,7 +274,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f32n_tensor_op_gmma_f32_warpspecialized_tensor_b cutlass::epilogue::thread::ReLu, cutlass::multiplies, cutlass::plus, - test::gemm::device::detail::Negate + cutlass::negate >, cutlass::gemm::EpilogueDefault>>; diff --git a/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu b/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu index 680c4dc20f..e4b92ff938 100644 --- a/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu +++ b/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu @@ -40,6 +40,7 @@ #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/epilogue/collective/default_epilogue.hpp" #include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" #include "../../common/cutlass_unit_test.h" @@ -86,6 +87,76 @@ TEST(SM90_Device_Gemm_f32t_f32n_f32n_tensor_op_gmma_f32, 64x128x32_1x2x1) { EXPECT_TRUE(test::gemm::device::TestAll()); } +TEST(SM90_Device_Gemm_f32t_f32t_f32n_tensor_op_gmma_f32, 64x128x32_1x1x1_pingpong) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::gemm::EpilogueTransposed + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + float, LayoutA, 4, + float, LayoutB, 4, + float, + Shape<_64,_128,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f32t_f32t_f32n_tensor_op_gmma_f32, 128x128x32_1x1x1_cooperative) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::gemm::EpilogueTransposed + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + float, LayoutA, 4, + float, LayoutB, 4, + float, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + /////////////////////////////////////////////////////////////////////////////// #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f8_f8_bf16_tensor_op_fp32.cu b/test/unit/gemm/device/sm90_gemm_f8_f8_bf16_tensor_op_fp32.cu new file mode 100644 index 0000000000..679ec647ca --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f8_f8_bf16_tensor_op_fp32.cu @@ -0,0 +1,523 @@ + +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + /*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// bf16 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_bf16n_tensor_op_gmma_f32, 64x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::bfloat16_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 16 / sizeof(cutlass::bfloat16_t), + cutlass::bfloat16_t, LayoutC, 16 / sizeof(cutlass::bfloat16_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// bf16 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_bf16n_tensor_op_gmma_f32, 64x128x128_relu) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::ReLu, cutlass::bfloat16_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 16 / sizeof(cutlass::bfloat16_t), + cutlass::bfloat16_t, LayoutC, 16 / sizeof(cutlass::bfloat16_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// bf16 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_bf16n_tensor_op_gmma_f32, 64x128x128_bias_bf16_relu) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::ReLu, cutlass::bfloat16_t, float, cutlass::bfloat16_t>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 16 / sizeof(cutlass::bfloat16_t), + cutlass::bfloat16_t, LayoutC, 16 / sizeof(cutlass::bfloat16_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// bf16 = e5m2 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e5m2t_e4m3n_bf16n_tensor_op_gmma_f32, 64x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::bfloat16_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 16 / sizeof(cutlass::bfloat16_t), + cutlass::bfloat16_t, LayoutC, 16 / sizeof(cutlass::bfloat16_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e5m2_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// bf16 = e4m3 * e5m2 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e5m2n_bf16n_tensor_op_gmma_f32, 64x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::bfloat16_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 16 / sizeof(cutlass::bfloat16_t), + cutlass::bfloat16_t, LayoutC, 16 / sizeof(cutlass::bfloat16_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e5m2_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////// Cluster 2x2x1 ////////////////////////////////// +///////////////////////////// bf16 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// +TEST(SM90_Device_Gemm_e4m3t_e4m3n_bf16n_tensor_op_gmma_f32, 64x128x128_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::bfloat16_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_2,_2,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 16 / sizeof(cutlass::bfloat16_t), + cutlass::bfloat16_t, LayoutC, 16 / sizeof(cutlass::bfloat16_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_2,_2,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////// Cluster 1x4x1 ////////////////////////////////// +///////////////////////////// bf16 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// +TEST(SM90_Device_Gemm_e4m3t_e4m3n_bf16n_tensor_op_gmma_f32, 64x128x128_1x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::bfloat16_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 16 / sizeof(cutlass::bfloat16_t), + cutlass::bfloat16_t, LayoutC, 16 / sizeof(cutlass::bfloat16_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_4,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////// Cluster 4x1x1 ////////////////////////////////// +///////////////////////////// bf16 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// +TEST(SM90_Device_Gemm_e4m3t_e4m3n_bf16n_tensor_op_gmma_f32, 64x128x128_4x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::bfloat16_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_4,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 16 / sizeof(cutlass::bfloat16_t), + cutlass::bfloat16_t, LayoutC, 16 / sizeof(cutlass::bfloat16_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_4,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////// Cluster 2x4x1 ////////////////////////////////// +///////////////////////////// bf16 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// +TEST(SM90_Device_Gemm_e4m3t_e4m3n_bf16n_tensor_op_gmma_f32, 64x128x128_2x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::bfloat16_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_2,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 16 / sizeof(cutlass::bfloat16_t), + cutlass::bfloat16_t, LayoutC, 16 / sizeof(cutlass::bfloat16_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_2,_4,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////////// TMA epilogue ///////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_bf16n_tensor_op_gmma_f32, 64x128x128_tma_epilogue) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::epilogue::TmaWarpSpecialized + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_bf16t_tensor_op_gmma_f32, 64x128x128_tma_epilogue_fp8_fast_accum) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::epilogue::TmaWarpSpecialized + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum + >::CollectiveOp; + + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f8_f8_bf16_tensor_op_fp32_evt.cu b/test/unit/gemm/device/sm90_gemm_f8_f8_bf16_tensor_op_fp32_evt.cu new file mode 100644 index 0000000000..a1d3711215 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f8_f8_bf16_tensor_op_fp32_evt.cu @@ -0,0 +1,197 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for Sm90 f8_f8_bf16 with EVT epilogue + ScaledLinCombPerRowBiasEltAct and ScaledLinCombPerRowBiasEltActAmaxAux +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_evt.hpp" +#include "sm90_evt_operations.hpp" + + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +TEST(SM90_Device_Gemm_e4m3t_e4m3n_bf16t_tensor_op_gmma_f32_epilogue, 64x128x128_ScaledLinCombPerRowBiasEltAct) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_64,_128,_128>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionCallbacks = cutlass::epilogue::fusion::Sm90ScaledLinCombPerRowBiasEltAct< + TileShape_MNK, // CtaTileShapeMNK + cutlass::epilogue::thread::ReLu, // ActivationFn + cutlass::bfloat16_t, // ElementOutput + float, // ElementCompute + cutlass::bfloat16_t // ElementBias + >; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::bfloat16_t, LayoutC, 8, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + using HostReference = test::gemm::device::HostScaledLinCombPerRowBiasEltAct< + Gemm, cutlass::epilogue::thread::ReLu, cutlass::bfloat16_t + >; + bool passed = test::gemm::device::TestAllEVT(true); + EXPECT_TRUE(passed); +} + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z +TEST(SM90_Device_Gemm_e4m3t_e4m3n_bf16n_tensor_op_gmma_f32_epilogue, 64x128x128_4x1x1_ScaledLinCombPerRowBiasEltActAmaxAux) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_64,_128,_128>; + using ClusterShape_MNK = Shape<_2,_4,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueDescriptor = cutlass::epilogue::collective::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::bfloat16_t, cutlass::bfloat16_t, EpilogueSchedule>; + using AuxStoreDescriptor = cutlass::epilogue::collective::AuxStoreDescriptor< + EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::bfloat16_t>; + + using FusionCallbacks = cutlass::epilogue::fusion::Sm90ScaledLinCombPerRowBiasEltActAmaxAux< + TileShape_MNK, // CtaTileShapeMNK + typename EpilogueDescriptor::EpilogueTile, // EpilogueTile + EpilogueDescriptor::StagesD, // StagesD + typename AuxStoreDescriptor::Stride, // StrideAux + typename AuxStoreDescriptor::SmemLayoutAtom, // SmemLayoutAtom + typename AuxStoreDescriptor::CopyOpR2S, // CopyOpR2S + cutlass::epilogue::thread::ReLu, // ActivationFn + cutlass::bfloat16_t, // ElementOutput + float, // ElementCompute + cutlass::bfloat16_t, // ElementBias + float // ElementScalar + >; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cutlass::bfloat16_t, LayoutC, 16, + cutlass::bfloat16_t, LayoutC, 16, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + using HostReference = test::gemm::device::HostScaledLinCombPerRowBiasEltActAmaxAux< + Gemm, cutlass::epilogue::thread::ReLu, cutlass::bfloat16_t + >; + bool passed = test::gemm::device::TestAllEVT(true); + EXPECT_TRUE(passed); +} +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_cluster_warpspecialized_cooperative.cu b/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_cluster_warpspecialized_cooperative.cu new file mode 100644 index 0000000000..6fd664b985 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_cluster_warpspecialized_cooperative.cu @@ -0,0 +1,533 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + /*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative, 128x128x128_1x1x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_128,_128,_128>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative, 256x128x128_1x1x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_256,_128,_128>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 2x1x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative, 128x128x128_1x2x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_128,_128,_128>; + using ClusterShape_MNK = Shape<_1,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative, 256x128x128_1x2x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_256,_128,_128>; + using ClusterShape_MNK = Shape<_1,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 1x4x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative, 128x128x128_1x4x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_128,_128,_128>; + using ClusterShape_MNK = Shape<_1,_4,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative, 256x128x128_1x4x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_256,_128,_128>; + using ClusterShape_MNK = Shape<_1,_4,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 4x1x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative, 128x128x128_4x1x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_128,_128,_128>; + using ClusterShape_MNK = Shape<_4,_1,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative, 256x128x128_4x1x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_256,_128,_128>; + using ClusterShape_MNK = Shape<_4,_1,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 2x4x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative, 128x128x128_2x4x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_128,_128,_128>; + using ClusterShape_MNK = Shape<_2,_4,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative, 256x128x128_2x4x1_fp8_fast_accum) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_256,_128,_128>; + using ClusterShape_MNK = Shape<_2,_4,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_cluster_warpspecialized_cooperative_evt.cu b/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_cluster_warpspecialized_cooperative_evt.cu new file mode 100644 index 0000000000..e560af72d4 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_cluster_warpspecialized_cooperative_evt.cu @@ -0,0 +1,197 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for Sm90 f8_f8_f32 with EVT epilogue + ScaledLinCombPerRowBiasEltAct and ScaledLinCombPerRowBiasEltActAmaxAux +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_evt.hpp" +#include "sm90_evt_operations.hpp" + + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_epilogue, 128x128x128_1x4x1_ScaledLinCombPerRowBiasEltAct) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using TileShape_MNK = Shape<_128,_128,_128>; + using ClusterShape_MNK = Shape<_1,_4,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionCallbacks = cutlass::epilogue::fusion::Sm90ScaledLinCombPerRowBiasEltAct< + TileShape_MNK, // CtaTileShapeMNK + cutlass::epilogue::thread::ReLu, // ActivationFn + float, // ElementOutput + float, // ElementCompute + float // ElementBias + >; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + using HostReference = test::gemm::device::HostScaledLinCombPerRowBiasEltAct< + Gemm, cutlass::epilogue::thread::ReLu, float + >; + bool passed = test::gemm::device::TestAllEVT(true); + EXPECT_TRUE(passed); +} + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 128x128x128_1x2x1_ScaledLinCombPerRowBiasEltActAmaxAux) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_128>; + using ClusterShape_MNK = Shape<_1,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueDescriptor = cutlass::epilogue::collective::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, float, float, EpilogueSchedule>; + using AuxStoreDescriptor = cutlass::epilogue::collective::AuxStoreDescriptor< + EpilogueDescriptor, cutlass::layout::RowMajor, float>; + + using FusionCallbacks = cutlass::epilogue::fusion::Sm90ScaledLinCombPerRowBiasEltActAmaxAux< + TileShape_MNK, // CtaTileShapeMNK + typename EpilogueDescriptor::EpilogueTile, // EpilogueTile + EpilogueDescriptor::StagesD, // StagesD + typename AuxStoreDescriptor::Stride, // StrideAux + typename AuxStoreDescriptor::SmemLayoutAtom, // SmemLayoutAtom + typename AuxStoreDescriptor::CopyOpR2S, // CopyOpR2S + cutlass::epilogue::thread::ReLu, // ActivationFn + float, // ElementOutput + float, // ElementCompute + float, // ElementBias + float // ElementScalar + >; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + using HostReference = test::gemm::device::HostScaledLinCombPerRowBiasEltActAmaxAux< + Gemm, cutlass::epilogue::thread::ReLu, float + >; + bool passed = test::gemm::device::TestAllEVT(true); + EXPECT_TRUE(passed); +} +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_cooperative_stream_k.cu b/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_cooperative_stream_k.cu new file mode 100644 index 0000000000..f879927d22 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_cooperative_stream_k.cu @@ -0,0 +1,544 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + /*! \file + \brief Tests for device-wide GEMM interface with stream-K interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 128x128x128_1x1x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_128,_128,_128>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 256x128x128_1x1x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_256,_128,_128>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 2x1x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 128x128x128_1x2x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_128,_128,_128>; + using ClusterShape_MNK = Shape<_1,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 256x128x128_1x2x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_256,_128,_128>; + using ClusterShape_MNK = Shape<_1,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 1x4x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 128x128x128_1x4x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_128,_128,_128>; + using ClusterShape_MNK = Shape<_1,_4,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 256x128x128_1x4x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_256,_128,_128>; + using ClusterShape_MNK = Shape<_1,_4,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 4x1x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 128x128x128_4x1x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_128,_128,_128>; + using ClusterShape_MNK = Shape<_4,_1,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 256x128x128_4x1x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_256,_128,_128>; + using ClusterShape_MNK = Shape<_4,_1,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////// Cluster 2x4x1 //////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 128x128x128_2x4x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_128,_128,_128>; + using ClusterShape_MNK = Shape<_2,_4,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative_stream_k, 256x128x128_2x4x1_fp8_fast_accum) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_256,_128,_128>; + using ClusterShape_MNK = Shape<_2,_4,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 16, + ElementB, LayoutB, 16, + ElementAccumulator, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp, + cutlass::gemm::StreamKScheduler + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_fp32.cu b/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_fp32.cu new file mode 100644 index 0000000000..38cc6a6d58 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_fp32.cu @@ -0,0 +1,554 @@ + +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + /*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// FP32 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32, 64x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, float, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 16 / sizeof(float), + float, LayoutC, 16 / sizeof(float), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// FP32 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32, 64x128x128_bias_f32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, float, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 16 / sizeof(float), + float, LayoutC, 16 / sizeof(float), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// FP32 = e5m2 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e5m2t_e4m3n_f32n_tensor_op_gmma_f32, 64x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, float, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 16 / sizeof(float), + float, LayoutC, 16 / sizeof(float), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e5m2_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// FP32 = e4m3 * e5m2 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e5m2n_f32n_tensor_op_gmma_f32, 64x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, float, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 16 / sizeof(float), + float, LayoutC, 16 / sizeof(float), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e5m2_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////// Cluster 2x2x1 ////////////////////////////////// +///////////////////////////// FP32 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32, 64x128x128_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, float, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_2,_2,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 16 / sizeof(float), + float, LayoutC, 16 / sizeof(float), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_2,_2,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////// Cluster 1x4x1 ////////////////////////////////// +///////////////////////////// FP32 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32, 64x128x128_1x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, float, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 16 / sizeof(float), + float, LayoutC, 16 / sizeof(float), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_4,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////// Cluster 4x1x1 ////////////////////////////////// +///////////////////////////// FP32 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32, 64x128x128_4x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, float, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_4,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 16 / sizeof(float), + float, LayoutC, 16 / sizeof(float), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_4,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////// Cluster 2x4x1 ////////////////////////////////// +///////////////////////////// FP32 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32, 64x128x128_2x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, float, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_2,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 16 / sizeof(float), + float, LayoutC, 16 / sizeof(float), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_2,_4,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////////// TMA epilogue ///////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32, 64x128x128_tma_epilogue) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::TmaWarpSpecialized + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32t_tensor_op_gmma_f32, 64x128x128_tma_epilogue) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::TmaWarpSpecialized + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32, 64x128x128_tma_epilogue_fp8_fast_accum) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::TmaWarpSpecialized + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum + >::CollectiveOp; + + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32t_tensor_op_gmma_f32, 64x128x128_tma_epilogue_fp8_fast_accum) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::TmaWarpSpecialized + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum + >::CollectiveOp; + + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f8_f8_f8_tensor_op_fp32.cu b/test/unit/gemm/device/sm90_gemm_f8_f8_f8_tensor_op_fp32.cu new file mode 100644 index 0000000000..eb60ad95b4 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f8_f8_f8_tensor_op_fp32.cu @@ -0,0 +1,1221 @@ + +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + /*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////////// output: E4M3 ///////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// e4m3 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e4m3n_tensor_op_gmma_f32, 64x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e4m3_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// e4m3 = e5m2 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e5m2t_e4m3n_e4m3n_tensor_op_gmma_f32, 64x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e4m3_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e5m2_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// e4m3 = e4m3 * e5m2 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e5m2n_e4m3n_tensor_op_gmma_f32, 64x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e4m3_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e5m2_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////// Cluster 2x2x1 ////////////////////////////// +///////////////////////////// e4m3 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e4m3n_tensor_op_gmma_f32, 64x128x128_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e4m3_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_2,_2,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_2,_2,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////// Cluster 1x4x1 ////////////////////////////// +///////////////////////////// e4m3 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e4m3n_tensor_op_gmma_f32, 64x128x128_1x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e4m3_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_4,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////// Cluster 4x1x1 ////////////////////////////// +///////////////////////////// e4m3 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e4m3n_tensor_op_gmma_f32, 64x128x128_4x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e4m3_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_4,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_4,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////// Cluster 2x4x1 ////////////////////////////// +///////////////////////////// e4m3 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e4m3n_tensor_op_gmma_f32, 64x128x128_2x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e4m3_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_2,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_2,_4,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////////// output: E5M2 ///////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// e5m2 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e5m2n_tensor_op_gmma_f32, 64x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e5m2_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// e5m2 = e5m2 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e5m2t_e4m3n_e5m2n_tensor_op_gmma_f32, 64x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e5m2_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e5m2_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// e5m2 = e4m3 * e5m2 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e5m2n_e5m2n_tensor_op_gmma_f32, 64x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e5m2_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e5m2_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////// Cluster 2x2x1 ////////////////////////////// +///////////////////////////// e5m2 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e5m2n_tensor_op_gmma_f32, 64x128x128_2x2x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e5m2_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_2,_2,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_2,_2,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////// Cluster 1x4x1 ////////////////////////////// +///////////////////////////// e5m2 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e5m2n_tensor_op_gmma_f32, 64x128x128_1x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e5m2_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_4,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////// Cluster 4x1x1 ////////////////////////////// +///////////////////////////// e5m2 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e5m2n_tensor_op_gmma_f32, 64x128x128_4x1x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e5m2_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_4,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_4,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////// Cluster 2x4x1 ////////////////////////////// +///////////////////////////// e5m2 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e5m2n_tensor_op_gmma_f32, 64x128x128_2x4x1) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e5m2_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_2,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_2,_4,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////// Cluster 2x4x1 ////////////////////////////// +///////////////////////////// e5m2 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e5m2n_tensor_op_gmma_f32, 64x128x128_2x4x1_persistent) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e5m2_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_2,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_2,_4,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////// Cluster 2x4x1 ////////////////////////////// +///////////////////////////// e5m2 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e5m2n_tensor_op_gmma_f32, 64x128x128_2x4x1_non_warpspecialized) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e5m2_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_2,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_2,_4,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + + + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////// output: E4M3 + Aux Tensor /////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// e4m3 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e4m3n_tensor_op_gmma_f32, 64x128x128_aux_tensor_e4m3) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + LayoutC, cutlass::epilogue::thread::Identity, cutlass::float_e4m3_t, float, cutlass::float_e4m3_t>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +////////////////////////////////// FP8 Accum ///////////////////////////////// +///////////////////////////// e5m2 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e5m2n_tensor_op_gmma_f32, 64x128x128_2x4x1_persistent_fp8_fast_accum) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e5m2_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_2,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_2,_4,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e5m2n_tensor_op_gmma_f32, 64x128x128_2x4x1_fp8_fast_accum) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e5m2_t, float, float>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_2,_4,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + cutlass::float_e5m2_t, LayoutC, 16 / sizeof(cutlass::float_e5m2_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_2,_4,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + + +/////////////////////////////////////////////////////////////////////////////// +////////////////////////// output: E4M3 + Bias /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// e4m3 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e4m3n_tensor_op_gmma_f32, 64x128x128_bias_bf16) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, cutlass::float_e4m3_t, float, cutlass::bfloat16_t>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + + +/////////////////////////////////////////////////////////////////////////////// +////////////////////////// output: E4M3 + Bias + Relu //////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// e4m3 = e4m3 * e4m3 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e4m3n_tensor_op_gmma_f32, 64x128x128_bias_bf16_relu) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::ReLu, cutlass::float_e4m3_t, float, cutlass::bfloat16_t>; + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////// output: E4M3 + Aux Tensor + Bias///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// e4m3 = e4m3 * e5m2 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e5m2n_e4m3n_tensor_op_gmma_f32, 64x128x128_aux_tensor_f16_bias_f16) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + LayoutC, cutlass::epilogue::thread::Identity, + cutlass::float_e4m3_t, // ElementOutput + float, // ElementCompute + cutlass::half_t, // ElementAux + float, // ElementAmax + cutlass::half_t>; // ElementBias + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e5m2_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////// output: E4M3 + Aux Tensor + Bias + Relu///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// e4m3 = e4m3 * e5m2 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e5m2n_e4m3n_tensor_op_gmma_f32, 64x128x128_aux_tensor_f16_relu) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + LayoutC, cutlass::epilogue::thread::ReLu, + cutlass::float_e4m3_t, // ElementOutput + float, // ElementCompute + cutlass::half_t, // ElementAux + float, // ElementAmax + float>; // ElementBias + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e5m2_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////// e4m3 = e4m3 * e5m2 (TN) ///////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e5m2n_e4m3n_tensor_op_gmma_f32, 64x128x128_aux_tensor_f16_bias_f16_relu) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + LayoutC, cutlass::epilogue::thread::ReLu, + cutlass::float_e4m3_t, // ElementOutput + float, // ElementCompute + cutlass::half_t, // ElementAux + float, // ElementAmax + cutlass::half_t>; // ElementBias + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + cutlass::float_e4m3_t, LayoutC, 16 / sizeof(cutlass::float_e4m3_t), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e5m2_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////////// TMA epilogue ///////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e4m3n_tensor_op_gmma_f32, 64x128x128_tma_epilogue) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, LayoutC, 16, + cutlass::float_e4m3_t, LayoutC, 16, + cutlass::epilogue::TmaWarpSpecialized + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_e4m3t_tensor_op_gmma_f32, 64x128x128_tma_epilogue) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, LayoutC, 16, + cutlass::float_e4m3_t, LayoutC, 16, + cutlass::epilogue::TmaWarpSpecialized + >::CollectiveOp; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + Shape<_64,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f8_f8_f8_tensor_op_fp32_evt.cu b/test/unit/gemm/device/sm90_gemm_f8_f8_f8_tensor_op_fp32_evt.cu new file mode 100644 index 0000000000..5738a2ac67 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f8_f8_f8_tensor_op_fp32_evt.cu @@ -0,0 +1,197 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for Sm90 f8_f8_bf16 with EVT epilogue + ScaledLinCombPerRowBiasEltAct and ScaledLinCombPerRowBiasEltActAmaxAux +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x_evt.hpp" +#include "sm90_evt_operations.hpp" + + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +TEST(SM90_Device_Gemm_f8t_f8n_f8t_tensor_op_gmma_f32_persistent_epilogue, 64x128x128_1x1x1_ScaledLinCombPerRowBiasEltAct) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_64,_128,_128>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionCallbacks = cutlass::epilogue::fusion::Sm90ScaledLinCombPerRowBiasEltAct< + TileShape_MNK, // CtaTileShapeMNK + cutlass::epilogue::thread::ReLu, // ActivationFn + cutlass::float_e4m3_t, // ElementOutput + float, // ElementCompute + cutlass::float_e4m3_t // ElementBias + >; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::float_e4m3_t, LayoutC, 16, + cutlass::float_e4m3_t, LayoutC, 16, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + using HostReference = test::gemm::device::HostScaledLinCombPerRowBiasEltAct< + Gemm, cutlass::epilogue::thread::ReLu, cutlass::float_e4m3_t + >; + bool passed = test::gemm::device::TestAllEVT(true); + EXPECT_TRUE(passed); +} + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z +TEST(SM90_Device_Gemm_f8t_f8n_f8t_tensor_op_gmma_f32_persistent_epilogue, 64x128x128_1x1x1_ScaledLinCombPerRowBiasEltActAmaxAux) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_64,_128,_128>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + using EpilogueDescriptor = cutlass::epilogue::collective::EpilogueDescriptor< + TileShape_MNK, EpilogueTileType, cutlass::float_e4m3_t, cutlass::float_e4m3_t, EpilogueSchedule>; + using AuxStoreDescriptor = cutlass::epilogue::collective::AuxStoreDescriptor< + EpilogueDescriptor, cutlass::layout::RowMajor, cutlass::float_e4m3_t>; + + using FusionCallbacks = cutlass::epilogue::fusion::Sm90ScaledLinCombPerRowBiasEltActAmaxAux< + TileShape_MNK, // CtaTileShapeMNK + typename EpilogueDescriptor::EpilogueTile, // EpilogueTile + EpilogueDescriptor::StagesD, // StagesD + typename AuxStoreDescriptor::Stride, // StrideAux + typename AuxStoreDescriptor::SmemLayoutAtom, // SmemLayoutAtom + typename AuxStoreDescriptor::CopyOpR2S, // CopyOpR2S + cutlass::epilogue::thread::ReLu, // ActivationFn + cutlass::float_e4m3_t, // ElementOutput + float, // ElementCompute + cutlass::float_e4m3_t, // ElementBias + float // ElementScalar + >; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + EpilogueTileType, + float, float, + cutlass::float_e4m3_t, LayoutC, 16, + cutlass::float_e4m3_t, LayoutC, 16, + EpilogueSchedule, + FusionCallbacks + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::float_e4m3_t, LayoutA, 16, + cutlass::float_e4m3_t, LayoutB, 16, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Host reference + using HostReference = test::gemm::device::HostScaledLinCombPerRowBiasEltActAmaxAux< + Gemm, cutlass::epilogue::thread::ReLu, cutlass::float_e4m3_t + >; + bool passed = test::gemm::device::TestAllEVT(true); + EXPECT_TRUE(passed); +} +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu b/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu new file mode 100644 index 0000000000..8e6a40d52a --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu @@ -0,0 +1,277 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests that the stream-K scheduler covers the entire problem space. +*/ + +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "../../common/cutlass_unit_test.h" + +using namespace cute; +using ProblemShape_MNKL = Shape; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel for getting each piece of work for a given block from the scheduler and logging +/// the K iterations visited by the block. +template < + class Scheduler, + class TileShape, + class ClusterShape +> +__global__ +void +run_scheduler(int* visit_counters, typename Scheduler::Params params, TileShape tile_shape, ClusterShape cluster_shape, ProblemShape_MNKL problem_shape_mnkl) { + Scheduler scheduler{params}; + auto work_tile_info = scheduler.get_current_work(); + + while (work_tile_info.is_valid_tile) { + // Increment counters to indicate coverage + auto tile_idx = Scheduler::output_tile_index(params, work_tile_info); + auto offset = tile_idx * params.k_iter_per_tile_ + work_tile_info.K_idx; + for (auto i = 0; i < work_tile_info.k_tile_count; ++i) { + // Use atomicAdd because the visit counters are shared by multiple thread blocks. + // While having more than one block increment the same counter indicates failure, + // we need to ensure that this behavior is captured (by having both increments reflected). + atomicAdd(visit_counters + offset + i, 1); + } + + bool continue_current = scheduler.continue_current_work(work_tile_info); + if (!continue_current) { + scheduler.advance_to_next_work(); + work_tile_info = scheduler.get_current_work(); + } + } +} + +/// Host-side wrapper for launching the kernel to test the scheduler. +template < + class TileShape, + class ClusterShape, + uint32_t NumMmaWarpGroups = 2 +> +bool +test_scheduler( + ProblemShape_MNKL problem_shape_mnkl, + TileShape tile_shape, + ClusterShape cluster_shape, + int sm_count, + int splits=1, + bool expect_data_parallel=false) { + + using Scheduler = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamK; + + cutlass::KernelHardwareInfo hw_info{0, sm_count}; + auto params = Scheduler::to_underlying_arguments(problem_shape_mnkl, tile_shape, cluster_shape, hw_info, {splits}, nullptr); + + // If we expect the schedule to be data-parallel only, ensure that no stream-K tiles are launched. + if (expect_data_parallel && params.sk_tiles_ != 0) { + return false; + } + + // Allocate counters indicating the number of times each k iteration of each output tile has been visited + auto [blk_m, blk_n, blk_l] = Scheduler::get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape); + auto total_counters = blk_m * blk_n * blk_l * params.k_iter_per_tile_; + cutlass::DeviceAllocation visit_counters(total_counters); + + // Initialize counters to zero + cudaError_t err = cudaMemset((void*)visit_counters.get(), 0, sizeof(int) * total_counters); + if (err != cudaSuccess) { + std::cerr << __FILE__ << ":" << __LINE__ << " cudaMemset failed with error: " << cudaGetErrorString(err) << std::endl; + return false; + } + + typename Scheduler::Arguments args{}; + + // Set up the grid for the problem + dim3 grid = Scheduler::get_grid_shape(problem_shape_mnkl, tile_shape, cluster_shape, hw_info, args); + + // Run the scheduler to completion and log visits to each k iteration + run_scheduler<<>>( + visit_counters.get(), params, tile_shape, cluster_shape, problem_shape_mnkl); + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + std::cerr << __FILE__ << ":" << __LINE__ << " scheduler kernel failed with error: " << cudaGetErrorString(err) << std::endl; + return false; + } + + // Copy visit counts back to host and ensure that all entries are ones + std::vector host_visit_counts(total_counters); + visit_counters.copy_to_host(host_visit_counts.data()); + + for (size_t i = 0; i < host_visit_counts.size(); ++i) { + if (host_visit_counts[i] != 1) { + // for (int count : host_visit_counts) { + // if (count != 1) { + std::cout << "Failed with problem size " + << size<0>(problem_shape_mnkl) << "x" + << size<1>(problem_shape_mnkl) << "x" + << size<2>(problem_shape_mnkl) << "x" + << size<3>(problem_shape_mnkl) + << " and grid size " << grid.x << "x" + << grid.y << "x" << grid.z + << " splits=" << params.splits_ + << " k_iter=" << params.k_iter_per_tile_ + << " big_units=" << params.big_units_ + << " sk_tiles=" << params.sk_tiles_ + << " sk_units=" << params.sk_units_ + << " k_iter_per_sk_unit=" << params.k_iter_per_sk_unit_ << std::endl; + std::cout << "Error at idx: " << i << ". Got count " << host_visit_counts[i] << std::endl; + return false; + } + } + + return true; +} + +/// Executes tests of the scheduler with a sweep across problem size K +template < + class TileShape, + class ClusterShape +> +bool sweep_k( + ProblemShape_MNKL problem_shape_mnkl, + TileShape tile_shape, + ClusterShape cluster_shape, + int sm_count, + int splits=1, + bool expect_data_parallel=false, + int k_start=128, + int k_stop=16384, + int k_step=0) { + + if (k_step == 0) { + k_step = 4 * cute::size<2>(tile_shape); + } + + for (int k = k_start; k <= k_stop; k += k_step) { + ProblemShape_MNKL problem{get<0>(problem_shape_mnkl), get<1>(problem_shape_mnkl), k, get<3>(problem_shape_mnkl)}; + bool passed = test_scheduler(problem, tile_shape, cluster_shape, sm_count, splits, expect_data_parallel); + if (!passed) { + return false; + } + } + + return true; +} + +/// Executes tests of the scheduler that are expected to result in a data-parallel schedule. +/// This function assumes that the problem, tile, and cluster shape, alongside the SM count, +/// are such that the problem executes only full waves on the device. +template < + class TileShape, + class ClusterShape +> +bool test_data_parallel( + int blocks_m, + int blocks_n, + TileShape tile_shape, + ClusterShape cluster_shape, + int sm_count) { + + // Since the configuration passed in executes only full waves, increasing + // the batch dimension simply results in running more full waves. + for (int l = 1; l < 4; ++l) { + ProblemShape_MNKL problem_shape{ + size<0>(tile_shape) * blocks_m, size<1>(tile_shape) * blocks_n, 1, l}; + bool passed = sweep_k(problem_shape, tile_shape, cluster_shape, sm_count, /*splits=*/1, /*expect_data_parallel=*/true); + + if (!passed) { + return false; + } + } + return true; +} + +/// Executes tests of the scheduler on the generic stream-K decomposition. +template < + class TileShape, + class ClusterShape +> +bool test_stream_k( + TileShape tile_shape, + ClusterShape cluster_shape, + int sm_count) { + + int tile_m = size<0>(tile_shape); + int tile_n = size<1>(tile_shape); + + for (int m_blocks = 1; m_blocks <= 24; ++m_blocks) { + for (int n_blocks = 1; n_blocks <= 24; ++n_blocks) { + for (int l = 1; l < 4; ++l) { + ProblemShape_MNKL problem{m_blocks * tile_m, n_blocks * tile_n, 1, l}; + if (!sweep_k(problem, tile_shape, cluster_shape, sm_count)) { + return false; + } + } + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_stream_k_scheduler, 256x128x64_2x1x1) { + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_1,_1>; + + TileShape_MNK tile_shape; + ClusterShape_MNK cluster_shape; + + // Test various data-parallel cases + EXPECT_TRUE(test_data_parallel(/*blocks_m=*/ 4, /*blocks_n=*/ 4, tile_shape, cluster_shape, /*sm_count=*/ 16)); + EXPECT_TRUE(test_data_parallel(/*blocks_m=*/16, /*blocks_n=*/ 4, tile_shape, cluster_shape, /*sm_count=*/ 64)); + EXPECT_TRUE(test_data_parallel(/*blocks_m=*/ 4, /*blocks_n=*/27, tile_shape, cluster_shape, /*sm_count=*/108)); + + // Test various stream-K cases + EXPECT_TRUE(test_stream_k(tile_shape, cluster_shape, /*sm_count=*/ 16)); + EXPECT_TRUE(test_stream_k(tile_shape, cluster_shape, /*sm_count=*/ 64)); + EXPECT_TRUE(test_stream_k(tile_shape, cluster_shape, /*sm_count=*/108)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_stream_k_scheduler, 128x128x64_2x1x1) { + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_1,_1>; + + TileShape_MNK tile_shape; + ClusterShape_MNK cluster_shape; + + EXPECT_TRUE(test_scheduler({128, 512, 2048, 1}, tile_shape, cluster_shape, 114)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu index a13f744a64..904b4bf39c 100644 --- a/test/unit/gemm/device/symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/symm_cf64_cf64_cf64_tensor_op_f64_sm90.cu @@ -49,7 +49,6 @@ #include "testbed_symm_universal.h" #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM90_Device_Symm_cf64n_cf64n_ls_l_tensor_op_f64_gaussian, 32x32x16_16x16x16) { diff --git a/test/unit/gemm/device/symm_f64_f64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/symm_f64_f64_tensor_op_f64_sm90.cu index 1feb2d67bb..fff373e23d 100644 --- a/test/unit/gemm/device/symm_f64_f64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/symm_f64_f64_tensor_op_f64_sm90.cu @@ -48,7 +48,6 @@ #include "testbed_symm_universal.h" #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM90_Device_Symm_f64n_f64n_rs_l_tensor_op_f64, 32x32x16_16x16x16) { @@ -132,4 +131,5 @@ TEST(SM90_Device_Symm_f64t_f64t_ls_l_tensor_op_f64, 128x128x16_32x64x16) { } ///////////////////////////////////////////////////////////////////////////////////////////////// + #endif // #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) diff --git a/test/unit/gemm/device/syr2k_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/syr2k_cf64_cf64_tensor_op_f64_sm90.cu index 76d19f650d..cb9f419ded 100644 --- a/test/unit/gemm/device/syr2k_cf64_cf64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/syr2k_cf64_cf64_tensor_op_f64_sm90.cu @@ -48,7 +48,6 @@ #include "testbed_rank2k_universal.h" #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM90_Device_Syr2k_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { diff --git a/test/unit/gemm/device/syr2k_f64_f64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/syr2k_f64_f64_tensor_op_f64_sm90.cu index f7aa84db3b..7af7a25c05 100644 --- a/test/unit/gemm/device/syr2k_f64_f64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/syr2k_f64_f64_tensor_op_f64_sm90.cu @@ -48,7 +48,6 @@ #include "testbed_rank2k_universal.h" #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM90_Device_Syr2k_f64n_f64n_l_tensor_op_f64, 32x32x16_16x16x16) { diff --git a/test/unit/gemm/device/syrk_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/syrk_cf64_cf64_tensor_op_f64_sm90.cu index 98da67d310..37a8a42ab5 100644 --- a/test/unit/gemm/device/syrk_cf64_cf64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/syrk_cf64_cf64_tensor_op_f64_sm90.cu @@ -48,7 +48,6 @@ #include "testbed_rank_k_universal.h" #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM90_Device_Syrk_cf64n_cf64n_l_tensor_op_f64, 32x32x16_16x16x16) { diff --git a/test/unit/gemm/device/syrk_f64_f64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/syrk_f64_f64_tensor_op_f64_sm90.cu index 8fe762775d..b98a58904f 100644 --- a/test/unit/gemm/device/syrk_f64_f64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/syrk_f64_f64_tensor_op_f64_sm90.cu @@ -48,7 +48,6 @@ #include "testbed_rank_k_universal.h" #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM90_Device_Syrk_f64n_f64t_l_tensor_op_f64, 128x64x16_64x32x16) { diff --git a/test/unit/gemm/device/testbed.h b/test/unit/gemm/device/testbed.h index dc21f41fe9..3e3178cbc3 100644 --- a/test/unit/gemm/device/testbed.h +++ b/test/unit/gemm/device/testbed.h @@ -159,7 +159,6 @@ struct Testbed { view.data(), view.capacity()); } else { - // TODO: Implement the rest EXPECT_TRUE(false) << "Not implemented"; return false; } diff --git a/test/unit/gemm/device/testbed_gemm_with_broadcast.h b/test/unit/gemm/device/testbed_gemm_with_broadcast.h index c28fc8dd81..9336874bae 100644 --- a/test/unit/gemm/device/testbed_gemm_with_broadcast.h +++ b/test/unit/gemm/device/testbed_gemm_with_broadcast.h @@ -186,7 +186,6 @@ struct TestbedGemmWithBroadcast { view.data(), view.capacity()); } else { - // TODO: Implement the rest EXPECT_TRUE(false) << "Not implemented"; return false; } diff --git a/test/unit/gemm/device/testbed_gemm_with_reduction.h b/test/unit/gemm/device/testbed_gemm_with_reduction.h index 6f220b1eb1..5c3e7353e7 100644 --- a/test/unit/gemm/device/testbed_gemm_with_reduction.h +++ b/test/unit/gemm/device/testbed_gemm_with_reduction.h @@ -178,7 +178,6 @@ struct TestbedGemmWithReduction { } } else { - // TODO: Implement the rest EXPECT_TRUE(false) << "Not implemented"; return false; } diff --git a/test/unit/gemm/device/testbed_interleaved.h b/test/unit/gemm/device/testbed_interleaved.h index b54a4b6b8e..98d5af9c62 100644 --- a/test/unit/gemm/device/testbed_interleaved.h +++ b/test/unit/gemm/device/testbed_interleaved.h @@ -103,7 +103,6 @@ struct InterleavedTestbed { view.data(), view.capacity()); } else { - // TODO: Implement the rest EXPECT_TRUE(false) << "Not implemented"; return false; } diff --git a/test/unit/gemm/device/testbed_sparse.h b/test/unit/gemm/device/testbed_sparse.h index ee513da12f..1e521ea7be 100644 --- a/test/unit/gemm/device/testbed_sparse.h +++ b/test/unit/gemm/device/testbed_sparse.h @@ -155,7 +155,6 @@ struct SparseTestbed { view.data(), view.capacity()); } else { - // TODO: Implement the rest EXPECT_TRUE(false) << "Not implemented"; return false; } @@ -173,10 +172,10 @@ struct SparseTestbed { tensor_B.resize(problem_size.kn()); if (tensor_C_row_broadcast) { tensor_C.resize({problem_size.m(), 1}); - } - else { + } else { tensor_C.resize(problem_size.mn()); } + tensor_D.resize(problem_size.mn()); reference_D.resize(problem_size.mn(), false); tensor_E.resize(cutlass::make_Coord( @@ -197,7 +196,6 @@ struct SparseTestbed { cutlass::reference::host::TensorFill(tensor_E.host_view(), (ElementE)(content)); } else { - // TODO: Implement the rest EXPECT_TRUE(false); } @@ -215,8 +213,7 @@ struct SparseTestbed { for (int i = 0; i < problem_size.m(); ++i) for (int j = 0; j < problem_size.n(); ++j) reference_D.host_view().at({i, j}) = tensor_C.host_view().at({i, 0}); - } - else { + } else { cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); } diff --git a/test/unit/gemm/device/testbed_trmm_universal.h b/test/unit/gemm/device/testbed_trmm_universal.h index db40eff767..422d58b9a0 100644 --- a/test/unit/gemm/device/testbed_trmm_universal.h +++ b/test/unit/gemm/device/testbed_trmm_universal.h @@ -140,7 +140,6 @@ struct TestbedTrmmUniversal { view.data(), view.capacity()); } else { - // TODO: Implement the rest EXPECT_TRUE(false) << "Not implemented"; return false; } @@ -186,7 +185,6 @@ struct TestbedTrmmUniversal { view, seed, Trmm::kFillMode, 0, 0.5, mantissa_in_bits); } else { - // TODO: Implement the rest EXPECT_TRUE(false) << "Not implemented"; return false; } @@ -230,7 +228,6 @@ struct TestbedTrmmUniversal { EXPECT_TRUE(false) << "Gaussian distribution for pad diagonal not implemented"; } else { - // TODO: Implement the rest EXPECT_TRUE(false) << "Not implemented"; return false; } diff --git a/test/unit/gemm/device/testbed_universal.h b/test/unit/gemm/device/testbed_universal.h index 615e9c5c6b..a849b593b9 100644 --- a/test/unit/gemm/device/testbed_universal.h +++ b/test/unit/gemm/device/testbed_universal.h @@ -135,7 +135,6 @@ struct TestbedUniversal { view.data(), view.capacity()); } else { - // TODO: Implement the rest EXPECT_TRUE(false) << "Not implemented"; return false; } diff --git a/test/unit/gemm/device/trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu index 437bed55b4..ba8084d3ac 100644 --- a/test/unit/gemm/device/trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/trmm_cf64_cf64_cf64_tensor_op_f64_sm90.cu @@ -49,7 +49,6 @@ #include "testbed_trmm_universal.h" #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM90_Device_Trmm_cf64n_cf64n_cf64t_ls_u_nu_tensor_op_f64_gaussian, 32x32x16_16x16x16) { diff --git a/test/unit/gemm/device/trmm_f64_f64_f64_tensor_op_f64_sm90.cu b/test/unit/gemm/device/trmm_f64_f64_f64_tensor_op_f64_sm90.cu index 5339bc556b..ffcefafefc 100644 --- a/test/unit/gemm/device/trmm_f64_f64_f64_tensor_op_f64_sm90.cu +++ b/test/unit/gemm/device/trmm_f64_f64_f64_tensor_op_f64_sm90.cu @@ -49,7 +49,6 @@ #include "testbed_trmm_universal.h" #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(SM90_Device_Trmm_f64n_f64n_f64t_rs_l_nu_tensor_op_f64, 32x32x16_16x16x16) { diff --git a/test/unit/gemm/threadblock/mma_multistage.cu b/test/unit/gemm/threadblock/mma_multistage.cu index 8025637a1b..1313b1abaa 100644 --- a/test/unit/gemm/threadblock/mma_multistage.cu +++ b/test/unit/gemm/threadblock/mma_multistage.cu @@ -2838,7 +2838,6 @@ TEST(SM80_gemm_threadblock_crosswise, } //////////////////////////////////////////////////////////////////////////////// - TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x64x1024_64x64x1024_16x8x256_3stage) { using ElementA = cutlass::uint1b_t; @@ -3328,7 +3327,6 @@ TEST(SM80_gemm_threadblock_crosswise, } //////////////////////////////////////////////////////////////////////////////// - TEST(SM80_gemm_threadblock_congruous, tensor_op_64x64x16_32x64x16_8x8x4_3stage) { using ElementA = double; diff --git a/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h b/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h index 6e14745eb5..44ef05305e 100644 --- a/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h +++ b/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h @@ -296,7 +296,6 @@ struct SparseTestbed { } else if (init_A == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); } else { - // TODO: Implement the rest return false; } @@ -322,7 +321,6 @@ struct SparseTestbed { } else if (init_B == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); } else { - // TODO: Implement the rest return false; } @@ -339,7 +337,6 @@ struct SparseTestbed { cutlass::reference::host::TensorFill(matrix_E.host_view(), (ElementE)(content)); } else { - // TODO: Implement the rest return false; } diff --git a/test/unit/gemm/threadblock/mma_multistage_testbed.h b/test/unit/gemm/threadblock/mma_multistage_testbed.h index 1e859b6184..bda862c828 100644 --- a/test/unit/gemm/threadblock/mma_multistage_testbed.h +++ b/test/unit/gemm/threadblock/mma_multistage_testbed.h @@ -253,7 +253,6 @@ struct Testbed { } else if (init_A == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); } else { - // TODO: Implement the rest return false; } @@ -279,7 +278,6 @@ struct Testbed { } else if (init_B == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); } else { - // TODO: Implement the rest return false; } diff --git a/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h b/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h index a47a30024a..8810f5faa3 100644 --- a/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h +++ b/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h @@ -244,7 +244,6 @@ struct Testbed { } else if (init_A == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); } else { - // TODO: Implement the rest return false; } @@ -270,7 +269,6 @@ struct Testbed { } else if (init_B == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); } else { - // TODO: Implement the rest return false; } diff --git a/test/unit/gemm/threadblock/mma_pipelined_sm75.cu b/test/unit/gemm/threadblock/mma_pipelined_sm75.cu index 3f173873d2..9c4f95d0c8 100644 --- a/test/unit/gemm/threadblock/mma_pipelined_sm75.cu +++ b/test/unit/gemm/threadblock/mma_pipelined_sm75.cu @@ -1793,7 +1793,6 @@ TEST(SM75_gemm_threadblock_interleaved, } //////////////////////////////////////////////////////////////////////////////// - TEST(SM75_gemm_threadblock_crosswise, tensor_op_64x64x512_64x64x512_8x8x128) { using ElementA = cutlass::uint1b_t; using LayoutA = cutlass::layout::RowMajor; diff --git a/test/unit/gemm/threadblock/mma_pipelined_testbed.h b/test/unit/gemm/threadblock/mma_pipelined_testbed.h index 6f36b53e71..ac088c28aa 100644 --- a/test/unit/gemm/threadblock/mma_pipelined_testbed.h +++ b/test/unit/gemm/threadblock/mma_pipelined_testbed.h @@ -262,7 +262,6 @@ struct Testbed { } else if (init_A == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); } else { - // TODO: Implement the rest return false; } @@ -288,7 +287,6 @@ struct Testbed { } else if (init_B == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); } else { - // TODO: Implement the rest return false; } diff --git a/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h b/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h index 9e8d351416..688514ca5d 100644 --- a/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h +++ b/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h @@ -250,7 +250,6 @@ struct Testbed { } else if (init_A == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); } else { - // TODO: Implement the rest return false; } @@ -276,7 +275,6 @@ struct Testbed { } else if (init_B == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); } else { - // TODO: Implement the rest return false; } diff --git a/test/unit/gemm/threadblock/mma_pipelined_wmma_sm75.cu b/test/unit/gemm/threadblock/mma_pipelined_wmma_sm75.cu index 12fae1f8a9..54d0e930b0 100644 --- a/test/unit/gemm/threadblock/mma_pipelined_wmma_sm75.cu +++ b/test/unit/gemm/threadblock/mma_pipelined_wmma_sm75.cu @@ -262,7 +262,6 @@ TEST(SM75_gemm_threadblock_wmma_tensor_op_row_col_col_s4, 64x64x64_64x64x64_8x8x problem_size.k(), alpha, beta) .run(grid, block); } - TEST(SM75_gemm_threadblock_wmma_tensor_op_row_col_row_b1, 64x64x512_64x64x512_8x8x128) { using ElementA = cutlass::uint1b_t; using LayoutA = cutlass::layout::RowMajor; @@ -332,6 +331,7 @@ TEST(SM75_gemm_threadblock_wmma_tensor_op_row_col_col_b1, 64x64x512_64x64x512_8x problem_size.k(), alpha, beta) .run(grid, block); } + #endif //CUTLASS_SUBBYTE_INTEGER_MATRIX_MULTIPLY_ENABLED #endif //CUTLASS_ARCH_WMMA_SM75_ENABLED diff --git a/test/unit/gemm/threadblock/mma_planar_complex_testbed.h b/test/unit/gemm/threadblock/mma_planar_complex_testbed.h index b33abdb623..45b50fac9e 100644 --- a/test/unit/gemm/threadblock/mma_planar_complex_testbed.h +++ b/test/unit/gemm/threadblock/mma_planar_complex_testbed.h @@ -232,7 +232,6 @@ struct TestbedPlanarComplex { } else if (init_A == cutlass::Distribution::Identity) { //cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); } else { - // TODO: Implement the rest return false; } @@ -270,7 +269,6 @@ struct TestbedPlanarComplex { //cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); } else { - // TODO: Implement the rest return false; } diff --git a/test/unit/gemm/threadblock/mma_singlestage_wmma_sm75.cu b/test/unit/gemm/threadblock/mma_singlestage_wmma_sm75.cu index 1b24ebe5ae..1c64422084 100644 --- a/test/unit/gemm/threadblock/mma_singlestage_wmma_sm75.cu +++ b/test/unit/gemm/threadblock/mma_singlestage_wmma_sm75.cu @@ -262,7 +262,6 @@ TEST(SM75_gemm_threadblock_singlestage_wmma_tensor_op_row_col_col_s4, 64x64x64_6 problem_size.k(), alpha, beta) .run(grid, block); } - TEST(SM75_gemm_threadblock_singlestage_wmma_tensor_op_row_col_row_b1, 64x64x512_64x64x512_8x8x128) { using ElementA = cutlass::uint1b_t; using LayoutA = cutlass::layout::RowMajor; diff --git a/test/unit/gemm/warp/gemm_complex_sm90.cu b/test/unit/gemm/warp/gemm_complex_sm90.cu index 38bdfa65d8..a1707de1bd 100644 --- a/test/unit/gemm/warp/gemm_complex_sm90.cu +++ b/test/unit/gemm/warp/gemm_complex_sm90.cu @@ -51,7 +51,6 @@ #include "testbed.h" #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - TEST(SM90_warp_gemm_complex_tensor_op_f64, 16x8x4_16x8x4_nt) { using Shape = cutlass::gemm::GemmShape<16, 8, 4>; @@ -330,5 +329,4 @@ TEST(SM90_warp_gemm_complex_tensor_op_f64, 64x64x4_16x8x4_tn) { test::gemm::warp::TestbedComplex().run(); } - #endif // if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) diff --git a/test/unit/gemm/warp/gemm_sm75.cu b/test/unit/gemm/warp/gemm_sm75.cu index 43f185d934..7ab4b21a9c 100644 --- a/test/unit/gemm/warp/gemm_sm75.cu +++ b/test/unit/gemm/warp/gemm_sm75.cu @@ -746,7 +746,6 @@ TEST(SM75_warp_gemm_tensor_op_interleaved_i4, 128x128x128_16x16x128_8x8x32) { } //////////////////////////////////////////////////////////////////////////////// - TEST(SM75_warp_gemm_tensor_op_crosswise_b1, 128x128x512_64x64x512_8x8x128) { using Shape = cutlass::gemm::GemmShape<64, 64, 512>; using InstructionShape = cutlass::gemm::GemmShape<8, 8, 128>; @@ -856,5 +855,4 @@ TEST(SM75_warp_gemm_tensor_op_crosswise_b1, 128x128x512_16x16x512_8x8x128) { } //////////////////////////////////////////////////////////////////////////////// - #endif diff --git a/test/unit/gemm/warp/gemm_sm80.cu b/test/unit/gemm/warp/gemm_sm80.cu index 54a0248e8a..4034767d69 100644 --- a/test/unit/gemm/warp/gemm_sm80.cu +++ b/test/unit/gemm/warp/gemm_sm80.cu @@ -1316,7 +1316,6 @@ TEST(SM80_warp_gemm_tensor_op_crosswise_i4, 128x128x256_16x16x256_16x8x64) { } //////////////////////////////////////////////////////////////////////////////// - TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x512_64x64x512_16x8x256) { using Shape = cutlass::gemm::GemmShape<64, 64, 512>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; @@ -1526,7 +1525,6 @@ TEST(SM80_warp_gemm_tensor_op_crosswise_b1, 128x128x1024_16x16x1024_16x8x256) { } //////////////////////////////////////////////////////////////////////////////// - TEST(SM80_warp_gemm_tensor_op_congruous_f64, 16x16x4_16x16x4_8x8x4) { using Shape = cutlass::gemm::GemmShape<16, 16, 4>; using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; diff --git a/test/unit/gemm/warp/gemm_sm90.cu b/test/unit/gemm/warp/gemm_sm90.cu index f417a41fcc..6c2cc78bbc 100644 --- a/test/unit/gemm/warp/gemm_sm90.cu +++ b/test/unit/gemm/warp/gemm_sm90.cu @@ -51,7 +51,6 @@ #include "testbed.h" #if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) - TEST(SM90_warp_gemm_tensor_op_congruous_f64, 16x16x4_16x16x4_16x8x4) { using Shape = cutlass::gemm::GemmShape<16, 16, 4>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; @@ -202,5 +201,4 @@ TEST(SM90_warp_gemm_tensor_op_crosswise_f64, 32x64x16_32x64x16_16x8x4) { .run(); } //////////////////////////////////////////////////////////////////////////////// - #endif // if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) diff --git a/test/unit/gemm/warp/testbed.h b/test/unit/gemm/warp/testbed.h index 3487aa0ffd..fe62ce4446 100644 --- a/test/unit/gemm/warp/testbed.h +++ b/test/unit/gemm/warp/testbed.h @@ -259,7 +259,6 @@ struct Testbed { } else if (init_A == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); } else { - // TODO: Implement the rest return false; } @@ -286,7 +285,6 @@ struct Testbed { } else if (init_B == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); } else { - // TODO: Implement the rest return false; } @@ -492,7 +490,6 @@ struct TestbedComplex { } else if (init_A == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); } else { - // TODO: Implement the rest return false; } @@ -506,7 +503,6 @@ struct TestbedComplex { } else if (init_B == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); } else { - // TODO: Implement the rest return false; } @@ -814,7 +810,6 @@ struct TransformTestbed { } else if (init_A == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); } else { - // TODO: Implement the rest return false; } @@ -839,7 +834,6 @@ struct TransformTestbed { } else if (init_B == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); } else { - // TODO: Implement the rest return false; } @@ -1041,7 +1035,6 @@ struct TransformedTestbedComplex { } else if (init_A == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); } else { - // TODO: Implement the rest return false; } @@ -1055,7 +1048,6 @@ struct TransformedTestbedComplex { } else if (init_B == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); } else { - // TODO: Implement the rest return false; } @@ -1410,7 +1402,6 @@ struct SparseTestbed { } else if (init_A == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); } else { - // TODO: Implement the rest return false; } @@ -1435,7 +1426,6 @@ struct SparseTestbed { } else if (init_B == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); } else { - // TODO: Implement the rest return false; } @@ -1463,7 +1453,6 @@ struct SparseTestbed { cutlass::reference::host::TensorFill(tensor_E.host_view(), (ElementE)(content)); } else { - // TODO: Implement the rest return false; } diff --git a/test/unit/gemm/warp/wmma_sm75.cu b/test/unit/gemm/warp/wmma_sm75.cu index ebc0f3b04a..81a98d4c1d 100644 --- a/test/unit/gemm/warp/wmma_sm75.cu +++ b/test/unit/gemm/warp/wmma_sm75.cu @@ -52,7 +52,6 @@ #include "cutlass/util/reference/host/gemm.h" #include "testbed.h" - /////////////////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////// SUBBYTE wmma.mma //////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/nvrtc/CMakeLists.txt b/test/unit/nvrtc/CMakeLists.txt index 112d25575b..c76581f9c0 100644 --- a/test/unit/nvrtc/CMakeLists.txt +++ b/test/unit/nvrtc/CMakeLists.txt @@ -56,8 +56,13 @@ endmacro() string(APPEND NVRTC_INCLUDES_STRINGS "char const *kCutlassHeaders[] = {\n") string(APPEND NVRTC_INCLUDES_NAMES "char const *kCutlassHeaderNames[] = {\n") + +file(GLOB_RECURSE NVRTC_SOURCES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} kernel/thread/*.hpp) + add_nvrtc_headers(${PROJECT_SOURCE_DIR}/include "${CUTLASS_CUTLASS};${CUTLASS_UTIL};${CUTLASS_DEVICE}") +add_nvrtc_headers(${PROJECT_SOURCE_DIR}/include "${CUTLASS_CUTE}") add_nvrtc_headers(${PROJECT_SOURCE_DIR}/test "${CUTLASS_NVRTC};${CUTLASS_UTIL};${CUTLASS_DEVICE}") +add_nvrtc_headers(${CMAKE_CURRENT_SOURCE_DIR} "${NVRTC_SOURCES}") add_nvrtc_headers("${CMAKE_CURRENT_SOURCE_DIR}/stdlib" "assert.h;stdint.h") if(CUTLASS_NVRTC_HAS_CUDA_FP16) diff --git a/test/unit/nvrtc/kernel/thread/contraction.hpp b/test/unit/nvrtc/kernel/thread/contraction.hpp new file mode 100644 index 0000000000..65c4437a2e --- /dev/null +++ b/test/unit/nvrtc/kernel/thread/contraction.hpp @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + + +namespace nvrtc { +namespace thread { + +template< + typename ElementA, typename ElementB, typename ElementC, + typename TileShape, typename ClusterShape, + bool kTransA, bool kTransB, + int RANK_M, int RANK_N, int RANK_K, int RANK_L +> +struct ContractionKernel { + +using ElementScalar = float; +using ElementAccum = float; +using EpilogueThread = cutlass::epilogue::thread::LinearCombination; + +static constexpr cute::GMMA::Major majorA = ! kTransA ? cute::GMMA::Major::MN : cute::GMMA::Major::K; +static constexpr cute::GMMA::Major majorB = ! kTransB ? cute::GMMA::Major::K : cute::GMMA::Major::MN; + +/// Kernel config +typedef int64_t stride_type; +typedef int32_t extent_type; + +static constexpr const stride_type* stride_null = nullptr; +static constexpr const extent_type* extent_null = nullptr; + +template +static constexpr +auto +make_stride_tuple(Indexable const& t, int n, int64_t init_default = 0) { + static_assert(Rank > 1); + if constexpr (IsMajor) { + return cute::transform(cute::make_seq{}, [&](auto i) { + if constexpr (i == 0) { + return cute::Int<1>{}; + } + else { + return i < n ? t[i] : init_default; + } + }); + } + else { + return cute::make_int_tuple(t, n, init_default); + } +} + +using StrideA = decltype(cute::make_stride( + make_stride_tuple(stride_null, 0, 0), + make_stride_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0))); + +using StrideB = decltype(cute::make_stride( + make_stride_tuple(stride_null, 0, 0), + make_stride_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0))); + +using StrideC = decltype(cute::make_stride( + cute::make_int_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0))); + +using ProblemShape = decltype(cute::make_shape( + cute::make_int_tuple(extent_null, 0, 0), + cute::make_int_tuple(extent_null, 0, 0), + cute::make_int_tuple(extent_null, 0, 0), + cute::make_int_tuple(extent_null, 0, 0))); + +using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, StrideA, 16 / sizeof(ElementA), + ElementB, StrideB, 16 / sizeof(ElementB), + ElementAccum, + TileShape, ClusterShape, cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized +>::CollectiveOp; + +using EpilogueOutputOp = cutlass::epilogue::collective::DefaultEpilogue; +using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter; +using Kernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveOp, + CollectiveEpilogue>; + +}; + +} // namespace nvrtc +} // namespace thread diff --git a/test/unit/nvrtc/thread/.gitignore b/test/unit/nvrtc/thread/.gitignore new file mode 100644 index 0000000000..9484314ac8 --- /dev/null +++ b/test/unit/nvrtc/thread/.gitignore @@ -0,0 +1 @@ +nvrtc_config.hpp diff --git a/test/unit/nvrtc/thread/CMakeLists.txt b/test/unit/nvrtc/thread/CMakeLists.txt index 59e5d0082d..e164604374 100644 --- a/test/unit/nvrtc/thread/CMakeLists.txt +++ b/test/unit/nvrtc/thread/CMakeLists.txt @@ -26,10 +26,15 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +configure_file(nvrtc_config.in nvrtc_config.hpp) + cutlass_test_unit_add_executable( cutlass_test_unit_nvrtc_thread - gemm_nvrtc.cu + nvrtc_gemm.cu + nvrtc_contraction.cu testbed.h - ) +) target_link_libraries(cutlass_test_unit_nvrtc_thread PRIVATE cutlass_nvrtc) + +target_include_directories(cutlass_test_unit_nvrtc_thread PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) diff --git a/test/unit/nvrtc/thread/nvrtc_config.in b/test/unit/nvrtc/thread/nvrtc_config.in new file mode 100644 index 0000000000..6291b93bd5 --- /dev/null +++ b/test/unit/nvrtc/thread/nvrtc_config.in @@ -0,0 +1,3 @@ +#pragma once + +#define CUDA_INCLUDE_DIR "@CUDA_TOOLKIT_ROOT_DIR@/include" diff --git a/test/unit/gemm/device/testing_elementwise.hpp b/test/unit/nvrtc/thread/nvrtc_contraction.cu similarity index 60% rename from test/unit/gemm/device/testing_elementwise.hpp rename to test/unit/nvrtc/thread/nvrtc_contraction.cu index a2d5b3ea02..934523b532 100644 --- a/test/unit/gemm/device/testing_elementwise.hpp +++ b/test/unit/nvrtc/thread/nvrtc_contraction.cu @@ -29,53 +29,38 @@ * **************************************************************************************************/ /*! \file - \brief Elementwise activation functors used only for testing purposes. + \brief Unit tests for GETT */ -#pragma once +#include +#include -#include -#include -#include +#include "testbed.h" -#include "../../common/cutlass_unit_test.h" +#include "nvrtc_config.hpp" -#include "cutlass/util/host_tensor.h" -#include "cutlass/util/tensor_view_io.h" -#include "cutlass/util/distribution.h" -#include "cutlass/util/packed_stride.hpp" -#include "cutlass/util/reference/host/tensor_fill.h" -#include "cutlass/util/reference/host/tensor_copy.h" -#include "cutlass/util/reference/host/tensor_compare.h" -#include "cutlass/util/reference/host/tensor_norm.h" -#include "cutlass/util/reference/host/gett.hpp" +#ifndef CUDA_INCLUDE_DIR +static_assert(0, "CUDA include path is not defined"); +#endif -#include "testbed_utils.h" +TEST(SM90_nvrtc_kernel, Contraction) { + static const char* nvrtc_opts[] = { + "-w", + "-default-device", + "-std=c++17", + "-arch=sm_90", + "-I" CUDA_INCLUDE_DIR, + }; -#include "cutlass/kernel_hardware_info.hpp" -#include "cutlass/layout/matrix.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/gemm/gemm.h" + EXPECT_TRUE(test::nvrtc::thread::TestbedKernel::compile( + "nvrtc::thread::ContractionKernel<" + "cutlass::bfloat16_t, cutlass::bfloat16_t, cutlass::bfloat16_t," + "cute::Shape, cute::Shape, cute::Shape>," + "cute::Shape," + "true, true," + "10, 10, 10, 10>::Kernel", + { nvrtc_opts, nvrtc_opts + 5 } + )); +} -#include "cute/int_tuple.hpp" - -namespace test { -namespace gemm { -namespace device { -namespace detail{ - -/// Simple activation function that negates the input. -template -struct Negate { - static constexpr T neg_one = T(-1); - - CUTLASS_HOST_DEVICE - T operator()(const T& data) { - return data * neg_one; - } -}; - -} // namespace detail -} // namespace device -} // namespace gemm -} // namespace test +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/nvrtc/thread/gemm_nvrtc.cu b/test/unit/nvrtc/thread/nvrtc_gemm.cu similarity index 100% rename from test/unit/nvrtc/thread/gemm_nvrtc.cu rename to test/unit/nvrtc/thread/nvrtc_gemm.cu diff --git a/test/unit/nvrtc/thread/testbed.h b/test/unit/nvrtc/thread/testbed.h index 378be81da0..0b8d3bd763 100644 --- a/test/unit/nvrtc/thread/testbed.h +++ b/test/unit/nvrtc/thread/testbed.h @@ -35,12 +35,15 @@ #pragma once #include +#include +#include #include "cutlass/gemm/thread/mma.h" #include "../kernel/thread/testbed_kernel.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/tensor_view_io.h" +#include "cutlass/trace.h" #include "cutlass/util/reference/host/tensor_copy.h" #include "cutlass/util/reference/host/tensor_fill.h" @@ -58,6 +61,78 @@ namespace test { namespace nvrtc { namespace thread { +#define NVRTC_RETURN_IF_ERROR(api) \ + do { \ + nvrtcResult _result = api; \ + if (_result != NVRTC_SUCCESS) { \ + CUTLASS_TRACE_HOST("Nvrtc error: " << _result); \ + return false; \ + } \ + } while(0) + +inline const char * cuda_source_fmt = R"""( + +#include "kernel/thread/contraction.hpp" + +using Operator = %s; + +extern "C" __global__ void global_entry(__grid_constant__ Operator::Params const params) { + extern __shared__ char smem[]; + + Operator op; + op(params, smem); +} + +)"""; + +struct TestbedKernel { + static bool compile(std::string const &kernel, std::vector const &opts) { + int sz = std::snprintf(nullptr, 0, cuda_source_fmt, kernel.c_str()); + std::vector cuda_source(sz + 1); + std::snprintf(&cuda_source[0], cuda_source.size(), cuda_source_fmt, kernel.c_str()); + + nvrtcProgram program; + NVRTC_RETURN_IF_ERROR( + nvrtcCreateProgram( + &program, + cuda_source.data(), + nullptr, + static_cast(cutlass::nvrtc::kCutlassHeaderCount), + cutlass::nvrtc::kCutlassHeaders, + cutlass::nvrtc::kCutlassHeaderNames) + ); + + nvrtcResult compile_result = + nvrtcCompileProgram( + program, + static_cast(opts.size()), + opts.data()); + + size_t log_size; + NVRTC_RETURN_IF_ERROR( + nvrtcGetProgramLogSize(program, &log_size) + ); + + if (log_size > 1) { + auto log = std::make_unique(log_size); + + NVRTC_RETURN_IF_ERROR( + nvrtcGetProgramLog(program, log.get()) + ); + + std::cout << log.get() << std::endl; + } + + NVRTC_RETURN_IF_ERROR(compile_result); + + NVRTC_RETURN_IF_ERROR( + nvrtcDestroyProgram(&program) + ); + + return true; + } +}; + /// Structure to compute the matrix product template < /// Size of the Gemm problem - concept: gemm::GemmShape<> diff --git a/test/unit/pipeline/pipeline_tma_async.cu b/test/unit/pipeline/pipeline_tma_async.cu index 0bd40b1f57..3253dfe293 100644 --- a/test/unit/pipeline/pipeline_tma_async.cu +++ b/test/unit/pipeline/pipeline_tma_async.cu @@ -79,7 +79,7 @@ void pipeline_device(uint32_t const NumIterations) using SharedStorage = SharedStorage; SharedStorage& shared_storage = *reinterpret_cast(shared_memory); - auto cta_layout = Layout{}; // (m,n) -> cta_id + [[maybe_unused]] auto cta_layout = Layout{}; // (m,n) -> cta_id int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); int warp_group_thread_idx = threadIdx.x % 128; dim3 block_id_in_cluster = cute::block_id_in_cluster(); @@ -196,7 +196,7 @@ struct PipelineTest { float elapsed_ms = 0.0f; // Pipeline (multistage pipeline) - auto num_stages = Int{}; + [[maybe_unused]] auto num_stages = Int{}; auto cluster_shape = Shape, Int, _1>{}; diff --git a/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu b/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu index 16a70a46f3..c6fa463a37 100644 --- a/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu +++ b/test/unit/pipeline/pipeline_tma_async_warp_specialized.cu @@ -87,7 +87,7 @@ void pipeline_device(KernelParams const kernel_params) using SharedStorage = SharedStorage; SharedStorage& shared_storage = *reinterpret_cast(shared_memory); - auto cta_layout = Layout{}; // (m,n) -> cta_id + [[maybe_unused]] auto cta_layout = Layout{}; // (m,n) -> cta_id int warp_group_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); int warp_group_thread_idx = threadIdx.x % 128; @@ -265,7 +265,7 @@ struct PipelineTest { float elapsed_ms = 0.0f; // Pipeline (multistage pipeline) - auto num_stages = Int{}; + [[maybe_unused]] auto num_stages = Int{}; auto cluster_shape = Shape, Int, _1>{}; // diff --git a/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu b/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu index 8fa645612a..f1e7e7f03d 100644 --- a/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu +++ b/test/unit/pipeline/pipeline_tma_async_warp_specialized_persistent.cu @@ -209,7 +209,7 @@ void pipeline_device(KernelParams params) using SharedStorage = SharedStorage; SharedStorage& shared_storage = *reinterpret_cast(shared_memory); - auto cta_layout = Layout{}; // (m,n) -> cta_id + [[maybe_unused]] auto cta_layout = Layout{}; // (m,n) -> cta_id int warp_group_idx = __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarpGroup, 0); int warp_group_thread_idx = threadIdx.x % NumThreadsPerWarpGroup; dim3 block_id_in_cluster = cute::block_id_in_cluster(); diff --git a/test/unit/reduction/kernel/reduce_splitk.cu b/test/unit/reduction/kernel/reduce_splitk.cu index 6a990f481b..2f36d62a7e 100644 --- a/test/unit/reduction/kernel/reduce_splitk.cu +++ b/test/unit/reduction/kernel/reduce_splitk.cu @@ -110,7 +110,6 @@ public: cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); } else { - // TODO: Implement the rest EXPECT_TRUE(false) << "Not implemented"; return false; } diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 606cf7f6cb..1d240bc4e2 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -25,12 +25,15 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + cmake_policy(SET CMP0112 NEW) + add_subdirectory(util) if (CUTLASS_ENABLE_LIBRARY) add_subdirectory(library) endif() + if (CUTLASS_ENABLE_PROFILER) if (NOT CUTLASS_ENABLE_LIBRARY) message(SEND_ERROR "Build conflict: The CUTLASS profiler requires the CUTLASS library.") diff --git a/tools/library/CMakeLists.txt b/tools/library/CMakeLists.txt index 8745f39eee..ffb67910af 100644 --- a/tools/library/CMakeLists.txt +++ b/tools/library/CMakeLists.txt @@ -30,11 +30,6 @@ include(GNUInstallDirs) find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED) -# Set Python3_EXECUTABLE to be visible from global scope. -# In CMake 3.24, this could be supported by adding the GLOBAL field -# to find_package above (https://cmake.org/cmake/help/latest/command/find_package.html#id7) -set(Python3_EXECUTABLE ${Python3_EXECUTABLE} CACHE INTERNAL "Path to python3 executable") - add_library(cutlass_library_includes INTERFACE) add_library(nvidia::cutlass::library::includes ALIAS cutlass_library_includes) set_target_properties(cutlass_library_includes PROPERTIES EXPORT_NAME library::includes) @@ -72,6 +67,7 @@ cutlass_add_library( src/util.cu src/reference/gemm.cu + src/reference/gemm_fp8.cu src/reference/initialize_reference_operations.cu @@ -93,6 +89,7 @@ file(GLOB_RECURSE GENERATOR_PYTHON_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOU # set cutlass generator compiler version to filter kernels in the generator not supported by a specific toolkit. set(CUTLASS_GENERATOR_CUDA_COMPILER_VERSION ${CMAKE_CUDA_COMPILER_VERSION}) +set(CUTLASS_LIBRARY_GENERATED_KERNEL_LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/generated_kernels.txt CACHE STRING "Generated kernel listing file") # --log-level is set to DEBUG to enable printing information about which kernels were excluded # from generation in /tools/library/scripts/manifest.py. To avoid having this information appear @@ -107,6 +104,7 @@ execute_process( --architectures "${CUTLASS_NVCC_ARCHS_ENABLED}" --kernels "${CUTLASS_LIBRARY_KERNELS}" --ignore-kernels "${CUTLASS_LIBRARY_IGNORE_KERNELS}" + --selected-kernel-list "${CUTLASS_LIBRARY_GENERATED_KERNEL_LIST_FILE}" --cuda-version "${CUTLASS_GENERATOR_CUDA_COMPILER_VERSION}" --log-level DEBUG RESULT_VARIABLE cutlass_lib_INSTANCE_GENERATION_RESULT @@ -197,3 +195,9 @@ install( LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} ) + +install( + FILES ${CUTLASS_LIBRARY_GENERATED_KERNEL_LIST_FILE} + DESTINATION ${CMAKE_INSTALL_INFODIR}/cutlass/ + ) + diff --git a/tools/library/include/cutlass/library/descriptions.h b/tools/library/include/cutlass/library/descriptions.h new file mode 100644 index 0000000000..e866996529 --- /dev/null +++ b/tools/library/include/cutlass/library/descriptions.h @@ -0,0 +1,601 @@ +/*************************************************************************************************** + * Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct MathInstructionDescription { + + /// Shape of the target math instruction + cutlass::gemm::GemmCoord instruction_shape; + + /// Describes the data type of the internal accumulator + NumericTypeID element_accumulator; + + /// Classification of math instruction + OpcodeClassID opcode_class; + + /// Type of math operation performed + MathOperationID math_operation; + + // + // Methods + // + + MathInstructionDescription( + cutlass::gemm::GemmCoord instruction_shape = cutlass::gemm::GemmCoord(), + NumericTypeID element_accumulator = NumericTypeID::kInvalid, + OpcodeClassID opcode_class = OpcodeClassID::kInvalid, + MathOperationID math_operation = MathOperationID::kMultiplyAdd + ): + instruction_shape(instruction_shape), + element_accumulator(element_accumulator), + opcode_class(opcode_class), + math_operation(math_operation) {} + + // Equality operator + inline + bool operator==(MathInstructionDescription const& rhs) const{ + return ( + (instruction_shape == rhs.instruction_shape) && + (element_accumulator == rhs.element_accumulator) && + (opcode_class == rhs.opcode_class) && + (math_operation == rhs.math_operation)); + } + + // Inequality operator + inline + bool operator!=(MathInstructionDescription const& rhs) const { + return !(*this == rhs); + } + +}; + +/// Structure describing the tiled structure of a GEMM-like computation +struct TileDescription { + + /// Describes the shape of a threadblock (in elements) + cutlass::gemm::GemmCoord threadblock_shape; + + /// Describes the number of pipeline stages in the threadblock-scoped mainloop + int threadblock_stages; + + /// Number of warps in each logical dimension + cutlass::gemm::GemmCoord warp_count; + + /// Core math instruction + MathInstructionDescription math_instruction; + + /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. + int minimum_compute_capability; + + /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. + int maximum_compute_capability; + + /// Describes the shape of a cluster (in blocks) + cutlass::gemm::GemmCoord cluster_shape; + + // + // Methods + // + + TileDescription( + cutlass::gemm::GemmCoord threadblock_shape = cutlass::gemm::GemmCoord(), + int threadblock_stages = 0, + cutlass::gemm::GemmCoord warp_count = cutlass::gemm::GemmCoord(), + MathInstructionDescription math_instruction = MathInstructionDescription(), + int minimum_compute_capability = 0, + int maximum_compute_capability = 0, + cutlass::gemm::GemmCoord cluster_shape = cutlass::gemm::GemmCoord(1,1,1) + ): + threadblock_shape(threadblock_shape), + threadblock_stages(threadblock_stages), + warp_count(warp_count), + math_instruction(math_instruction), + minimum_compute_capability(minimum_compute_capability), + maximum_compute_capability(maximum_compute_capability), + cluster_shape(cluster_shape) { } + + // Equality operator + inline + bool operator==(TileDescription const& rhs) const{ + return ( + (threadblock_shape == rhs.threadblock_shape) && + (threadblock_stages == rhs.threadblock_stages) && + (warp_count == rhs.warp_count) && + (math_instruction == rhs.math_instruction) && + (minimum_compute_capability == rhs.minimum_compute_capability) && + (maximum_compute_capability == rhs.maximum_compute_capability)); + } + + // Inequality operator + inline + bool operator!=(TileDescription const& rhs) const { + return !(*this == rhs); + } +}; + +/// High-level description of an operation +struct OperationDescription { + + /// Unique identifier describing the operation + char const * name; + + /// Operation provider + Provider provider; + + /// Kind of operation + OperationKind kind; + + /// Describes the tiled structure of a GEMM-like computation + TileDescription tile_description; + + // + // Methods + // + OperationDescription( + char const * name = "unknown", + Provider provider = Provider::kInvalid, + OperationKind kind = OperationKind::kInvalid, + TileDescription const& tile_description = TileDescription() + ): + name(name), provider(provider), kind(kind), tile_description(tile_description) { } +}; + +/// Structure describing the properties of a tensor +struct TensorDescription { + + /// Numeric type of an individual element + NumericTypeID element; + + /// Enumerant identifying the layout function for the tensor + LayoutTypeID layout; + + /// Alignment restriction on pointers, strides, and extents + int alignment; + + /// log2() of the maximum extent of each dimension + int log_extent_range; + + /// log2() of the maximum value each relevant stride may have + int log_stride_range; + + // + // Methods + // + + TensorDescription( + NumericTypeID element = NumericTypeID::kInvalid, + LayoutTypeID layout = LayoutTypeID::kInvalid, + int alignment = 1, + int log_extent_range = 24, + int log_stride_range = 24 + ): + element(element), + layout(layout), + alignment(alignment), + log_extent_range(log_extent_range), + log_stride_range(log_stride_range) { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all GEMM computations +struct GemmDescription : public OperationDescription { + + /// Indicates the kind of GEMM performed + GemmKind gemm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source matrix + TensorDescription C; + + /// Describes the destination matrix + TensorDescription D; + + /// Describes the sparse meta matrices + TensorDescription E; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + // + // Methods + // + + GemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} + + GemmDescription( + OperationDescription op_desc, + GemmKind gemm_kind, + TensorDescription const& A, + TensorDescription const& B, + TensorDescription const& C, + TensorDescription const& D, + NumericTypeID element_epilogue, + SplitKMode split_k_mode, + ComplexTransform transform_A, + ComplexTransform transform_B + ): + OperationDescription(op_desc), + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description for structured sparse GEMMs. +struct SparseGemmDescription : public GemmDescription { + + /// Description structure for structured sparse GEMM + SparseGemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = TensorDescription(), + TensorDescription const& E = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + GemmDescription(gemm_kind, A, B, C, D, element_epilogue, split_k_mode, transform_A, transform_B) + {this->E = E;} +}; + +/// Description of all Reduction operations +struct ReductionDescription : public OperationDescription { + + /// Describes the data type of workspace + NumericTypeID element_workspace; + + /// Describes the data type of final output + NumericTypeID element_output; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; +}; + +/// Description of all Rank K update computations (SYRK, HERK, SYR2K, HER2K) +struct RankKDescription : public OperationDescription { + + /// Indicates which device template is used (universal or regular) + RankKKind rank_k_kind; + + /// Number of rank update (rank k or rank 2k) + int num_ranks; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand (used only for SYR2K and HER2K) + TensorDescription B; + + /// Describes the source and destination matrices + TensorDescription C; + + /// Describes the fill mode for matrix C + FillMode fill_mode; + + /// Describes the blas mode (symmetric/hermitian) + BlasMode blas_mode; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + // + // Methods + // + + RankKDescription( + RankKKind rank_k_kind = RankKKind::kUniversal, + int num_ranks = 1, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + FillMode fill_mode = FillMode::kInvalid, + BlasMode blas_mode = BlasMode::kInvalid, + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + rank_k_kind(rank_k_kind), + num_ranks(num_ranks), + A(A), + B(B), + C(C), + fill_mode(fill_mode), + blas_mode(blas_mode), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all TRMM computations +struct TrmmDescription : public OperationDescription { + + /// Indicates the kind of TRMM performed + TrmmKind trmm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the side mode for matrix A + SideMode side_mode; + + /// Describes the fill mode for matrix A + FillMode fill_mode; + + /// Describes the diag type for matrix A + DiagType diag_type; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source and destination matrices + TensorDescription D; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + // + // Methods + // + + TrmmDescription( + TrmmKind trmm_kind = TrmmKind::kUniversal, + TensorDescription const& A = TensorDescription(), + SideMode side_mode = SideMode::kInvalid, + FillMode fill_mode = FillMode::kInvalid, + DiagType diag_type = DiagType::kInvalid, + TensorDescription const& B = TensorDescription(), + TensorDescription const& D = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone + ): + trmm_kind(trmm_kind), + A(A), + side_mode(side_mode), + fill_mode(fill_mode), + diag_type(diag_type), + B(B), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all SYMM/HEMM update computations +struct SymmDescription : public OperationDescription { + + /// Indicates which device template is used (universal or regular) + SymmKind symm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source and destination matrices + TensorDescription C; + + /// Describes the side mode for matrix A + SideMode side_mode; + + /// Describes the fill mode for matrix A + FillMode fill_mode; + + /// Describes the blas mode (symmetric/hermitian) + BlasMode blas_mode; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + // + // Methods + // + + SymmDescription( + SymmKind symm_kind = SymmKind::kUniversal, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + SideMode side_mode = SideMode::kInvalid, + FillMode fill_mode = FillMode::kInvalid, + BlasMode blas_mode = BlasMode::kInvalid, + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + symm_kind(symm_kind), + A(A), + B(B), + C(C), + side_mode(side_mode), + fill_mode(fill_mode), + blas_mode(blas_mode), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all Conv2d operations +struct ConvDescription : public OperationDescription { + /// Describes the convolution dimension support (2D or 3D) + int conv_dim; + + /// Describes the kind of convolution + ConvKind conv_kind; + + /// Describes the type of iterator algorithm (analytic or precomputed) + IteratorAlgorithmID iterator_algorithm; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the C operand + TensorDescription C; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + // + // Methods + // + // Returns Activation TensorDescription + TensorDescription activation() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return A; + case library::ConvKind::kDgrad : return C; + case library::ConvKind::kWgrad : return B; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Filter TensorDescription + TensorDescription filter() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return B; + case library::ConvKind::kDgrad : return B; + case library::ConvKind::kWgrad : return C; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Output TensorDescription + TensorDescription output() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return C; + case library::ConvKind::kDgrad : return A; + case library::ConvKind::kWgrad : return A; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h index 387765e5ac..f298c6d56f 100644 --- a/tools/library/include/cutlass/library/library.h +++ b/tools/library/include/cutlass/library/library.h @@ -54,6 +54,8 @@ #include #include "cutlass/cutlass.h" +#include "cutlass/library/types.h" +#include "cutlass/library/descriptions.h" #include "cutlass/matrix_coord.h" #include "cutlass/tensor_coord.h" #include "cutlass/layout/tensor.h" @@ -71,751 +73,9 @@ namespace library { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Layout type identifier -enum class LayoutTypeID { - kUnknown, - kColumnMajor, - kRowMajor, - kColumnMajorInterleavedK2, - kRowMajorInterleavedK2, - kColumnMajorInterleavedK4, - kRowMajorInterleavedK4, - kColumnMajorInterleavedK16, - kRowMajorInterleavedK16, - kColumnMajorInterleavedK32, - kRowMajorInterleavedK32, - kColumnMajorInterleavedK64, - kRowMajorInterleavedK64, - kTensorNCHW, - kTensorNCDHW, - kTensorNHWC, - kTensorNDHWC, - kTensorNC32HW32, - kTensorC32RSK32, - kTensorNC64HW64, - kTensorC64RSK64, - kInvalid -}; - -/// Numeric data type -enum class NumericTypeID { - kUnknown, - kVoid, - kB1, - kU2, - kU4, - kU8, - kU16, - kU32, - kU64, - kS2, - kS4, - kS8, - kS16, - kS32, - kS64, - kFE4M3, - kFE5M2, - kF16, - kBF16, - kTF32, - kF32, - kF64, - kCF16, - kCBF16, - kCF32, - kCTF32, - kCF64, - kCS2, - kCS4, - kCS8, - kCS16, - kCS32, - kCS64, - kCU2, - kCU4, - kCU8, - kCU16, - kCU32, - kCU64, - kInvalid -}; - -/// Enumerated type describing a transformation on a complex value. -enum class ComplexTransform { - kNone, - kConjugate, - kInvalid -}; - -/// Providers -enum class Provider { - kNone, - kCUTLASS, - kReferenceHost, - kReferenceDevice, - kCUBLAS, - kCUDNN, - kInvalid -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Enumeration indicating the kind of operation -enum class OperationKind { - kGemm, - kRankK, - kRank2K, - kTrmm, - kSymm, - kConv2d, - kConv3d, - kEqGemm, - kSparseGemm, - kReduction, - kInvalid -}; - -/// Enumeration indicating whether scalars are in host or device memory -enum class ScalarPointerMode { - kHost, - kDevice, - kInvalid -}; - -/// Describes how reductions are performed across threadblocks -enum class SplitKMode { - kNone, - kSerial, - kParallel, - kParallelSerial, - kInvalid -}; - -/// Indicates the classificaition of the math instruction -enum class OpcodeClassID { - kSimt, - kTensorOp, - kWmmaTensorOp, - kSparseTensorOp, - kInvalid -}; - -enum class MathOperationID { - kAdd, - kMultiplyAdd, - kMultiplyAddSaturate, - kMultiplyAddFastBF16, - kMultiplyAddFastF16, - kMultiplyAddFastF32, - kMultiplyAddComplex, - kMultiplyAddComplexFastF32, - kMultiplyAddGaussianComplex, - kXorPopc, - kInvalid -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Enumeration indicating what kind of GEMM operation to perform -enum class GemmKind { - kGemm, - kSparse, - kUniversal, - kPlanarComplex, - kPlanarComplexArray, - kGrouped, - kInvalid -}; - /// Mode of Universal GEMM using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; -/// Enumeration indicating what kind of RankK update operation to perform -enum class RankKKind { - kUniversal, - kInvalid -}; - -/// Enumeration indicating what kind of TRMM operation to perform -enum class TrmmKind { - kUniversal, - kInvalid -}; - -/// Enumeration indicating what kind of SYMM/HEMM operation to perform -enum class SymmKind { - kUniversal, - kInvalid -}; - -/// Enumeration indicating what kind of Conv2d operation to perform -enum class ConvKind { - kUnknown, - kFprop, - kDgrad, - kWgrad, - kInvalid -}; - -enum class ConvModeID { - kCrossCorrelation, - kConvolution, - kInvalid -}; - -// Iterator algorithm enum in order of general performance-efficiency -enum class IteratorAlgorithmID { - kNone, - kAnalytic, - kOptimized, - kFixedChannels, - kFewChannels, - kInvalid -}; - - -enum class EpilogueKind { - kUnknown, - kConversion, - kLinearCombination, - kLinearCombinationClamp, - kLinearCombinationPlanarComplex, - kLinearCombinationRelu, - kLinearCombinationSigmoid, - kInvalid -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -struct MathInstructionDescription { - - /// Shape of the target math instruction - cutlass::gemm::GemmCoord instruction_shape; - - /// Describes the data type of the internal accumulator - NumericTypeID element_accumulator; - - /// Classification of math instruction - OpcodeClassID opcode_class; - - /// Type of math operation performed - MathOperationID math_operation; - - // - // Methods - // - - MathInstructionDescription( - cutlass::gemm::GemmCoord instruction_shape = cutlass::gemm::GemmCoord(), - NumericTypeID element_accumulator = NumericTypeID::kInvalid, - OpcodeClassID opcode_class = OpcodeClassID::kInvalid, - MathOperationID math_operation = MathOperationID::kMultiplyAdd - ): - instruction_shape(instruction_shape), - element_accumulator(element_accumulator), - opcode_class(opcode_class), - math_operation(math_operation) {} - - // Equality operator - inline - bool operator==(MathInstructionDescription const& rhs) const{ - return ( - (instruction_shape == rhs.instruction_shape) && - (element_accumulator == rhs.element_accumulator) && - (opcode_class == rhs.opcode_class) && - (math_operation == rhs.math_operation)); - } - - // Inequality operator - inline - bool operator!=(MathInstructionDescription const& rhs) const { - return !(*this == rhs); - } - -}; - -/// Structure describing the tiled structure of a GEMM-like computation -struct TileDescription { - - /// Describes the shape of a threadblock (in elements) - cutlass::gemm::GemmCoord threadblock_shape; - - /// Describes the number of pipeline stages in the threadblock-scoped mainloop - int threadblock_stages; - - /// Number of warps in each logical dimension - cutlass::gemm::GemmCoord warp_count; - - /// Core math instruction - MathInstructionDescription math_instruction; - - /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. - int minimum_compute_capability; - - /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. - int maximum_compute_capability; - - /// Describes the shape of a cluster (in blocks) - cutlass::gemm::GemmCoord cluster_shape; - - // - // Methods - // - - TileDescription( - cutlass::gemm::GemmCoord threadblock_shape = cutlass::gemm::GemmCoord(), - int threadblock_stages = 0, - cutlass::gemm::GemmCoord warp_count = cutlass::gemm::GemmCoord(), - MathInstructionDescription math_instruction = MathInstructionDescription(), - int minimum_compute_capability = 0, - int maximum_compute_capability = 0, - cutlass::gemm::GemmCoord cluster_shape = cutlass::gemm::GemmCoord(1,1,1) - ): - threadblock_shape(threadblock_shape), - threadblock_stages(threadblock_stages), - warp_count(warp_count), - math_instruction(math_instruction), - minimum_compute_capability(minimum_compute_capability), - maximum_compute_capability(maximum_compute_capability), - cluster_shape(cluster_shape) { } - - // Equality operator - inline - bool operator==(TileDescription const& rhs) const{ - return ( - (threadblock_shape == rhs.threadblock_shape) && - (threadblock_stages == rhs.threadblock_stages) && - (warp_count == rhs.warp_count) && - (math_instruction == rhs.math_instruction) && - (minimum_compute_capability == rhs.minimum_compute_capability) && - (maximum_compute_capability == rhs.maximum_compute_capability)); - } - - // Inequality operator - inline - bool operator!=(TileDescription const& rhs) const { - return !(*this == rhs); - } -}; - -/// High-level description of an operation -struct OperationDescription { - - /// Unique identifier describing the operation - char const * name; - - /// Operation provider - Provider provider; - - /// Kind of operation - OperationKind kind; - - /// Describes the tiled structure of a GEMM-like computation - TileDescription tile_description; - - // - // Methods - // - OperationDescription( - char const * name = "unknown", - Provider Provider = Provider::kInvalid, - OperationKind kind = OperationKind::kInvalid, - TileDescription const & tile_description = TileDescription() - ): - name(name), kind(kind), tile_description(tile_description) { } -}; - -/// Structure describing the properties of a tensor -struct TensorDescription { - - /// Numeric type of an individual element - NumericTypeID element; - - /// Enumerant identifying the layout function for the tensor - LayoutTypeID layout; - - /// Alignment restriction on pointers, strides, and extents - int alignment; - - /// log2() of the maximum extent of each dimension - int log_extent_range; - - /// log2() of the maximum value each relevant stride may have - int log_stride_range; - - // - // Methods - // - - TensorDescription( - NumericTypeID element = NumericTypeID::kInvalid, - LayoutTypeID layout = LayoutTypeID::kInvalid, - int alignment = 1, - int log_extent_range = 24, - int log_stride_range = 24 - ): - element(element), - layout(layout), - alignment(alignment), - log_extent_range(log_extent_range), - log_stride_range(log_stride_range) { } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Description of all GEMM computations -struct GemmDescription : public OperationDescription { - - /// Indicates the kind of GEMM performed - GemmKind gemm_kind; - - /// Describes the A operand - TensorDescription A; - - /// Describes the B operand - TensorDescription B; - - /// Describes the source matrix - TensorDescription C; - - /// Describes the destination matrix - TensorDescription D; - - /// Describes the sparse meta matrices - TensorDescription E; - - /// Describes the data type of the scalars passed to the epilogue - NumericTypeID element_epilogue; - - /// Describes the structure of parallel reductions - SplitKMode split_k_mode; - - /// Transformation on A operand - ComplexTransform transform_A; - - /// Transformation on B operand - ComplexTransform transform_B; - - // - // Methods - // - - GemmDescription( - GemmKind gemm_kind = GemmKind::kGemm, - TensorDescription const &A = TensorDescription(), - TensorDescription const &B = TensorDescription(), - TensorDescription const &C = TensorDescription(), - TensorDescription const &D = TensorDescription(), - NumericTypeID element_epilogue = NumericTypeID::kInvalid, - SplitKMode split_k_mode = SplitKMode::kNone, - ComplexTransform transform_A = ComplexTransform::kNone, - ComplexTransform transform_B = ComplexTransform::kNone - ): - gemm_kind(gemm_kind), - A(A), - B(B), - C(C), - D(D), - element_epilogue(element_epilogue), - split_k_mode(split_k_mode), - transform_A(transform_A), - transform_B(transform_B) {} -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Description for structured sparse GEMMs. -struct SparseGemmDescription : public GemmDescription { - - /// Description structure for structured sparse GEMM - SparseGemmDescription( - GemmKind gemm_kind = GemmKind::kGemm, - TensorDescription const &A = TensorDescription(), - TensorDescription const &B = TensorDescription(), - TensorDescription const &C = TensorDescription(), - TensorDescription const &D = TensorDescription(), - TensorDescription const &E = TensorDescription(), - NumericTypeID element_epilogue = NumericTypeID::kInvalid, - SplitKMode split_k_mode = SplitKMode::kNone, - ComplexTransform transform_A = ComplexTransform::kNone, - ComplexTransform transform_B = ComplexTransform::kNone - ): - GemmDescription(gemm_kind, A, B, C, D, element_epilogue, split_k_mode, transform_A, transform_B) - {this->E = E;} -}; - -/// Description of all Reduction operations -struct ReductionDescription : public OperationDescription { - - /// Describes the data type of workspace - NumericTypeID element_workspace; - - /// Describes the data type of final output - NumericTypeID element_output; - - /// Describes the data type of the scalars passed to the epilogue - NumericTypeID element_epilogue; -}; - -/// Description of all Rank K update computations (SYRK, HERK, SYR2K, HER2K) -struct RankKDescription : public OperationDescription { - - /// Indicates which device template is used (universal or regular) - RankKKind rank_k_kind; - - /// Number of rank update (rank k or rank 2k) - int num_ranks; - - /// Describes the A operand - TensorDescription A; - - /// Describes the B operand (used only for SYR2K and HER2K) - TensorDescription B; - - /// Describes the source and destination matrices - TensorDescription C; - - /// Describes the fill mode for matrix C - FillMode fill_mode; - - /// Describes the blas mode (symmetric/hermitian) - BlasMode blas_mode; - - /// Describes the data type of the scalars passed to the epilogue - NumericTypeID element_epilogue; - - /// Describes the structure of parallel reductions - SplitKMode split_k_mode; - - /// Transformation on A operand - ComplexTransform transform_A; - - /// Transformation on B operand - ComplexTransform transform_B; - - // - // Methods - // - - RankKDescription( - RankKKind rank_k_kind = RankKKind::kUniversal, - int num_ranks = 1, - TensorDescription const &A = TensorDescription(), - TensorDescription const &B = TensorDescription(), - TensorDescription const &C = TensorDescription(), - FillMode fill_mode = FillMode::kInvalid, - BlasMode blas_mode = BlasMode::kInvalid, - NumericTypeID element_epilogue = NumericTypeID::kInvalid, - SplitKMode split_k_mode = SplitKMode::kNone, - ComplexTransform transform_A = ComplexTransform::kNone, - ComplexTransform transform_B = ComplexTransform::kNone - ): - rank_k_kind(rank_k_kind), - num_ranks(num_ranks), - A(A), - B(B), - C(C), - fill_mode(fill_mode), - blas_mode(blas_mode), - element_epilogue(element_epilogue), - split_k_mode(split_k_mode), - transform_A(transform_A), - transform_B(transform_B) {} -}; -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Description of all TRMM computations -struct TrmmDescription : public OperationDescription { - - /// Indicates the kind of TRMM performed - TrmmKind trmm_kind; - - /// Describes the A operand - TensorDescription A; - - /// Describes the side mode for matrix A - SideMode side_mode; - - /// Describes the fill mode for matrix A - FillMode fill_mode; - - /// Describes the diag type for matrix A - DiagType diag_type; - - /// Describes the B operand - TensorDescription B; - - /// Describes the source and destination matrices - TensorDescription D; - - /// Describes the data type of the scalars passed to the epilogue - NumericTypeID element_epilogue; - - /// Describes the structure of parallel reductions - SplitKMode split_k_mode; - - /// Transformation on A operand - ComplexTransform transform_A; - - // - // Methods - // - - TrmmDescription( - TrmmKind trmm_kind = TrmmKind::kUniversal, - TensorDescription const &A = TensorDescription(), - SideMode side_mode = SideMode::kInvalid, - FillMode fill_mode = FillMode::kInvalid, - DiagType diag_type = DiagType::kInvalid, - TensorDescription const &B = TensorDescription(), - TensorDescription const &D = TensorDescription(), - NumericTypeID element_epilogue = NumericTypeID::kInvalid, - SplitKMode split_k_mode = SplitKMode::kNone, - ComplexTransform transform_A = ComplexTransform::kNone - ): - trmm_kind(trmm_kind), - A(A), - side_mode(side_mode), - fill_mode(fill_mode), - diag_type(diag_type), - B(B), - D(D), - element_epilogue(element_epilogue), - split_k_mode(split_k_mode), - transform_A(transform_A) {} -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Description of all SYMM/HEMM update computations -struct SymmDescription : public OperationDescription { - - /// Indicates which device template is used (universal or regular) - SymmKind symm_kind; - - /// Describes the A operand - TensorDescription A; - - /// Describes the B operand - TensorDescription B; - - /// Describes the source and destination matrices - TensorDescription C; - - /// Describes the side mode for matrix A - SideMode side_mode; - - /// Describes the fill mode for matrix A - FillMode fill_mode; - - /// Describes the blas mode (symmetric/hermitian) - BlasMode blas_mode; - - /// Describes the data type of the scalars passed to the epilogue - NumericTypeID element_epilogue; - - /// Describes the structure of parallel reductions - SplitKMode split_k_mode; - - /// Transformation on A operand - ComplexTransform transform_A; - - /// Transformation on B operand - ComplexTransform transform_B; - - // - // Methods - // - - SymmDescription( - SymmKind symm_kind = SymmKind::kUniversal, - TensorDescription const &A = TensorDescription(), - TensorDescription const &B = TensorDescription(), - TensorDescription const &C = TensorDescription(), - SideMode side_mode = SideMode::kInvalid, - FillMode fill_mode = FillMode::kInvalid, - BlasMode blas_mode = BlasMode::kInvalid, - NumericTypeID element_epilogue = NumericTypeID::kInvalid, - SplitKMode split_k_mode = SplitKMode::kNone, - ComplexTransform transform_A = ComplexTransform::kNone, - ComplexTransform transform_B = ComplexTransform::kNone - ): - symm_kind(symm_kind), - A(A), - B(B), - C(C), - side_mode(side_mode), - fill_mode(fill_mode), - blas_mode(blas_mode), - element_epilogue(element_epilogue), - split_k_mode(split_k_mode), - transform_A(transform_A), - transform_B(transform_B) {} -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Description of all Conv2d operations -struct ConvDescription : public OperationDescription { - /// Describes the convolution dimension support (2D or 3D) - int conv_dim; - - /// Describes the kind of convolution - ConvKind conv_kind; - - /// Describes the type of iterator algorithm (analytic or precomputed) - IteratorAlgorithmID iterator_algorithm; - - /// Describes the A operand - TensorDescription A; - - /// Describes the B operand - TensorDescription B; - - /// Describes the C operand - TensorDescription C; - - /// Describes the data type of the scalars passed to the epilogue - NumericTypeID element_epilogue; - - // - // Methods - // - // Returns Activation TensorDescription - TensorDescription activation() const { - switch(conv_kind) { - case library::ConvKind::kFprop : return A; - case library::ConvKind::kDgrad : return C; - case library::ConvKind::kWgrad : return B; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns Filter TensorDescription - TensorDescription filter() const { - switch(conv_kind) { - case library::ConvKind::kFprop : return B; - case library::ConvKind::kDgrad : return B; - case library::ConvKind::kWgrad : return C; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - - // Returns Output TensorDescription - TensorDescription output() const { - switch(conv_kind) { - case library::ConvKind::kFprop : return C; - case library::ConvKind::kDgrad : return A; - case library::ConvKind::kWgrad : return A; - default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); - } - } - -}; - ///////////////////////////////////////////////////////////////////////////////////////////////// /// Base class for all operations @@ -1030,6 +290,7 @@ struct GemmUniversalArguments { // Needed for some 3.x kernels int sm_count; + }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/include/cutlass/library/types.h b/tools/library/include/cutlass/library/types.h new file mode 100644 index 0000000000..9f0673f93d --- /dev/null +++ b/tools/library/include/cutlass/library/types.h @@ -0,0 +1,258 @@ +/*************************************************************************************************** + * Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + #pragma once + + ///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Layout type identifier +enum class LayoutTypeID { + kUnknown, + kColumnMajor, + kRowMajor, + kColumnMajorInterleavedK2, + kRowMajorInterleavedK2, + kColumnMajorInterleavedK4, + kRowMajorInterleavedK4, + kColumnMajorInterleavedK16, + kRowMajorInterleavedK16, + kColumnMajorInterleavedK32, + kRowMajorInterleavedK32, + kColumnMajorInterleavedK64, + kRowMajorInterleavedK64, + kTensorNCHW, + kTensorNCDHW, + kTensorNHWC, + kTensorNDHWC, + kTensorNC32HW32, + kTensorC32RSK32, + kTensorNC64HW64, + kTensorC64RSK64, + kInvalid +}; + +/// Numeric data type +enum class NumericTypeID { + kUnknown, + kVoid, + kB1, + kU2, + kU4, + kU8, + kU16, + kU32, + kU64, + kS2, + kS4, + kS8, + kS16, + kS32, + kS64, + kFE4M3, + kFE5M2, + kF16, + kBF16, + kTF32, + kF32, + kF64, + kCF16, + kCBF16, + kCF32, + kCTF32, + kCF64, + kCS2, + kCS4, + kCS8, + kCS16, + kCS32, + kCS64, + kCU2, + kCU4, + kCU8, + kCU16, + kCU32, + kCU64, + kInvalid +}; + +/// Enumerated type describing a transformation on a complex value. +enum class ComplexTransform { + kNone, + kConjugate, + kInvalid +}; + +/// Providers +enum class Provider { + kNone, + kCUTLASS, + kReferenceHost, + kReferenceDevice, + kCUBLAS, + kCUDNN, + kInvalid +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Enumeration indicating the kind of operation +enum class OperationKind { + kGemm, + kRankK, + kRank2K, + kTrmm, + kSymm, + kConv2d, + kConv3d, + kEqGemm, + kSparseGemm, + kReduction, + kInvalid +}; + +/// Enumeration indicating whether scalars are in host or device memory +enum class ScalarPointerMode { + kHost, + kDevice, + kInvalid +}; + +/// Describes how reductions are performed across threadblocks +enum class SplitKMode { + kNone, + kSerial, + kParallel, + kParallelSerial, + kInvalid +}; + +/// Indicates the classificaition of the math instruction +enum class OpcodeClassID { + kSimt, + kTensorOp, + kWmmaTensorOp, + kSparseTensorOp, + kInvalid +}; + +enum class MathOperationID { + kAdd, + kMultiplyAdd, + kMultiplyAddSaturate, + kMultiplyAddFastBF16, + kMultiplyAddFastF16, + kMultiplyAddFastF32, + kMultiplyAddComplex, + kMultiplyAddComplexFastF32, + kMultiplyAddGaussianComplex, + kXorPopc, + kInvalid +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Enumeration indicating what kind of GEMM operation to perform +enum class GemmKind { + kGemm, + kSparse, + kUniversal, + kPlanarComplex, + kPlanarComplexArray, + kGrouped, + kInvalid +}; + +/// Enumeration indicating what kind of RankK update operation to perform +enum class RankKKind { + kUniversal, + kInvalid +}; + +/// Enumeration indicating what kind of TRMM operation to perform +enum class TrmmKind { + kUniversal, + kInvalid +}; + +/// Enumeration indicating what kind of SYMM/HEMM operation to perform +enum class SymmKind { + kUniversal, + kInvalid +}; + +/// Enumeration indicating what kind of Conv2d operation to perform +enum class ConvKind { + kUnknown, + kFprop, + kDgrad, + kWgrad, + kInvalid +}; + +enum class ConvModeID { + kCrossCorrelation, + kConvolution, + kInvalid +}; + +// Iterator algorithm enum in order of general performance-efficiency +enum class IteratorAlgorithmID { + kNone, + kAnalytic, + kOptimized, + kFixedChannels, + kFewChannels, + kInvalid +}; + + +enum class EpilogueKind { + kUnknown, + kConversion, + kLinearCombination, + kLinearCombinationClamp, + kLinearCombinationPlanarComplex, + kLinearCombinationRelu, + kLinearCombinationSigmoid, + kInvalid +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/gemm_operation.py b/tools/library/scripts/gemm_operation.py index 0be042f9bf..a0370edebb 100644 --- a/tools/library/scripts/gemm_operation.py +++ b/tools/library/scripts/gemm_operation.py @@ -24,7 +24,8 @@ class GemmOperation: # def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None, - kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto): + kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto, + tile_scheduler = TileSchedulerType.Default): self.prefix = "3x" if gemm_kind == GemmKind.Universal3x else "" self.operation_kind = OperationKind.Gemm @@ -46,6 +47,7 @@ def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, self.element_epilogue = element_epilogue self.epilogue_functor = epilogue_functor self.swizzling_functor = swizzling_functor + self.tile_scheduler = tile_scheduler # def is_complex(self): @@ -86,6 +88,7 @@ def core_name(self): math_operations_map = { MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and' } if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ @@ -176,7 +179,7 @@ def procedural_name(self): ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] if self.arch >= 90: - kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}" + kernel_name_template = "cutlass{p}_sm{ar}_{op}_{ex}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{l}_{s}_align{al}{k}{e}{t}" return kernel_name_template.format( p = self.prefix, ar = self.arch, @@ -192,7 +195,8 @@ def procedural_name(self): s = self.layout_name_3x(), al = str(max(self.A.alignment, self.B.alignment)), k = self.kernel_schedule_name_3x(), - e = self.epilogue_schedule_name_3x()) + e = self.epilogue_schedule_name_3x(), + t = TileSchedulerSuffixes[self.tile_scheduler]) else: threadblock = self.tile_description.procedural_name() return "cutlass{p}_{op}_{ex}_{tb}_{l}_align{a}".format( @@ -666,7 +670,8 @@ def __init__(self, operation_suffix = ''): using ${operation_name}_base = cutlass::gemm::kernel::GemmUniversal< cute::Shape, ${operation_name}_mainloop, - ${operation_name}_epilogue>; + ${operation_name}_epilogue, + ${tile_scheduler}>; // Define named type struct ${operation_name} : @@ -752,6 +757,7 @@ def emit(self, operation): 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], 'epilogue_vector_length': str(epilogue_vector_length), 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'tile_scheduler': str(TileSchedulerTag[operation.tile_scheduler]) } return SubstituteTemplate(self.gemm_template, values) diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index 09d92e9aa1..630364c162 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -95,7 +95,8 @@ def CreateGemmUniversal3xOperator( schedules = [[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto]], complex_transforms=None, epilogue_functor=EpilogueFunctor.LinearCombination, - swizzling_functor=SwizzlingFunctor.Identity1): + swizzling_functor=SwizzlingFunctor.Identity1, + tile_schedulers=[TileSchedulerType.Persistent]): if type(data_types) is dict: data_types = [data_types] @@ -112,27 +113,25 @@ def CreateGemmUniversal3xOperator( if manifest.kernel_filter == '': tile_descriptions = [tile_descriptions[0]] - for layout in layouts: - for tile_description in tile_descriptions: - for data_type in data_types: - for complex_transform in complex_transforms: - for kernel_schedule, epilogue_schedule in schedules: - A = TensorDescription( - data_type["a_type"], layout[0][0], layout[0][1], complex_transform[0]) - B = TensorDescription( - data_type["b_type"], layout[1][0], layout[1][1], complex_transform[1]) + combinations = product(layouts, tile_descriptions, data_types, complex_transforms, schedules, tile_schedulers) + for layout, tile_description, data_type, complex_transform, schedules, tile_scheduler in combinations: + kernel_schedule, epilogue_schedule = schedules + A = TensorDescription( + data_type["a_type"], layout[0][0], layout[0][1], complex_transform[0]) + B = TensorDescription( + data_type["b_type"], layout[1][0], layout[1][1], complex_transform[1]) - C = TensorDescription(data_type["c_type"], layout[2][0], layout[2][1]) - D = TensorDescription(data_type["d_type"], layout[2][0], layout[2][1]) + C = TensorDescription(data_type["c_type"], layout[2][0], layout[2][1]) + D = TensorDescription(data_type["d_type"], layout[2][0], layout[2][1]) - element_compute = data_type.get("epi_type", data_type["acc_type"]) - operation = GemmOperation( - GemmKind.Universal3x, tile_description.minimum_compute_capability, - tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D, - kernel_schedule, epilogue_schedule) + element_compute = data_type.get("epi_type", data_type["acc_type"]) + operation = GemmOperation( + GemmKind.Universal3x, tile_description.minimum_compute_capability, + tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D, + kernel_schedule, epilogue_schedule, tile_scheduler) - manifest.append(operation) - operations.append(operation) + manifest.append(operation) + operations.append(operation) return operations @@ -1006,7 +1005,7 @@ def GenerateSM61_Simt(manifest, cuda_version): data_type, alignment_constraints) CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints) + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) # # @@ -1735,19 +1734,23 @@ def GenerateSM75_TensorOp_88128(manifest, cuda_version): ] min_cc = 75 - max_cc = 1024 + max_cc = { + MathOperation.xor_popc: 89, + MathOperation.and_popc: 90 + } + alignment_constraints = [128,] for math_inst in math_instructions: tile_descriptions = [ - TileDescription([256, 128, 512], 2, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 512], 2, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 512], 2, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 512], 2, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 512], 2, [4, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 256, 512], 2, [2, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 128, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 256, 512], 2, [1, 4, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([256, 64, 512], 2, [4, 1, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 128, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([128, 64, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), + TileDescription([ 64, 64, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc[math_inst.math_operation]), ] data_type = [DataType.b1, DataType.b1, DataType.s32, DataType.s32] @@ -2151,17 +2154,25 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 6, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 64], 6, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), ] @@ -2191,7 +2202,10 @@ def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): for op in operations: if op.tile_description.threadblock_shape[1] >= 128: - op.C.alignment = 16 + if op.tile_description.threadblock_shape[0] == 32: + op.C.alignment = 8 + else: + op.C.alignment = 16 else: op.C.alignment = 8 @@ -2504,11 +2518,17 @@ def GenerateSM80_TensorOp_168256(manifest, cuda_version): DataType.b1, DataType.b1, DataType.s32, \ OpcodeClass.TensorOp, \ MathOperation.xor_popc), + MathInstruction( \ + [16, 8, 256], \ + DataType.b1, DataType.b1, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.and_popc), ] min_cc = 80 max_cc = { - MathOperation.xor_popc: 1024 + MathOperation.xor_popc: 89, + MathOperation.and_popc: 90 } alignment_constraints = [128,] @@ -4133,15 +4153,21 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version): [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized], [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized] ] + stream_k_schedules = [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized]] else: schedules = [ [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto], [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized] # TmaWarpSpecializedCooperative and TmaWarpSpecializedPingpong require CUDA version >= 12.1 for optimal performance. ] + stream_k_schedules = [] CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, schedules) + if CudaToolkitVersionSatisfies(cuda_version, 12, 1): + # Add stream-K variants + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK]) + # persistent kernels with TMA epilogues if CudaToolkitVersionSatisfies(cuda_version, 12, 1): # not enough smem for 256x128 f32 out with C allocation @@ -4149,16 +4175,29 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version): CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_medium, data_type, [[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_medium, data_type, + [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]], + tile_schedulers=[TileSchedulerType.StreamK]) else: CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, [[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]], + tile_schedulers=[TileSchedulerType.StreamK]) + # Emit instance without C allocation + load data_type["c_type"] = DataType.void CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, [[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]], + tile_schedulers=[TileSchedulerType.StreamK]) + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) if math_inst.element_a != math_inst.element_accumulator: @@ -4184,12 +4223,21 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version): CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, [[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, + [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]], + tile_schedulers=[TileSchedulerType.StreamK]) + # Emit instance without C allocation+load data_type_mixed["c_type"] = DataType.void CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, [[KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, + [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]], + tile_schedulers=[TileSchedulerType.StreamK]) + # def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): @@ -4209,9 +4257,28 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): OpcodeClass.TensorOp, MathOperation.multiply_add) + math_inst_largeN = MathInstruction( + [64, 256, 8], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + min_cc = 90 max_cc = 90 + tile_descriptions_large = [ + TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst_largeN.instruction_shape[0]*2, math_inst_largeN.instruction_shape[1], math_inst_largeN.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst_largeN.instruction_shape[0]*2, math_inst_largeN.instruction_shape[1], math_inst_largeN.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst_largeN.instruction_shape[0]*2, math_inst_largeN.instruction_shape[1], math_inst_largeN.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst_largeN, min_cc, max_cc, [1,1,1]), + ] + tile_descriptions_medium = [ TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), @@ -4220,6 +4287,7 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), ] + tile_descriptions_small = [ TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), @@ -4261,28 +4329,43 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.EpilogueTransposed] ] - # TMA kernels with TN or NN layout - layouts_tf32_tn_nn = [layouts_tf32[0], layouts_tf32[2]] + # TMA kernels with TN, NN, or NT layout + layouts_tf32_tn_nn_nt = [layouts_tf32[0], layouts_tf32[2], layouts_tf32[3]] + # TMA kernels with TT layout + layouts_tf32_tt = [layouts_tf32[1]] + if CudaToolkitVersionSatisfies(cuda_version, 12, 1): - CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions_small, data_types, [ + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn_nt, tile_descriptions_small, data_types, [ + [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], + [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized] + ]) + + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn_nt, tile_descriptions_medium, data_types, [ [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized] ]) - CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions_medium, data_types, [ + + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn_nt, tile_descriptions_large, data_types, [ + [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized], + ]) + + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn_nt, tile_descriptions_medium, data_types, [ [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative], [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized] ]) - else: - CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn, tile_descriptions, data_types, schedules_default) - # TMA kernels with NT layout, only support 64x128x32 tile for now. - layouts_tf32_nt = [layouts_tf32[3]] - CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_descriptions_small, data_types, schedules_default) - CreateGemmUniversal3xOperator(manifest, layouts_tf32_nt, tile_descriptions_medium, data_types, [ - [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized] - ]) - - layouts_tf32_tt = [layouts_tf32[1]] + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions_small, data_types, [ + [KernelScheduleType.TmaWarpSpecializedPingpong, EpilogueScheduleType.EpilogueTransposed] + ]) + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions_medium, data_types, [ + [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.EpilogueTransposed] + ]) + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions_large, data_types, [ + [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.EpilogueTransposed] + ]) + else: + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn_nt, tile_descriptions, data_types, schedules_default) + CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_types, schedules_transposed_epilogue) # @@ -4378,7 +4461,220 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): # Cooperative persistent CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_medium, data_type, [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative], - [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized]]) + [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized]], + tile_schedulers=[TileSchedulerType.Persistent, TileSchedulerType.StreamK] + ) + +def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 16], [LayoutType.ColumnMajor, 16], [LayoutType.ColumnMajor, 1]], # TN Layout + ] + + math_instructions = [ + # inst 64x128x32 + MathInstruction( + [64, 128, 32], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 128, 32], + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 128, 32], + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 128, 32], + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + # inst 64x64x32 + MathInstruction( + [64, 64, 32], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 64, 32], + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 64, 32], + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 64, 32], + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + + min_cc = 90 + max_cc = 90 + + for math_inst in math_instructions: + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + ] + + if math_inst.instruction_shape[1] == 128: + tile_descriptions_small = [ + # 64x128x128 + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + tile_descriptions = [ + # 128x128x128 + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + + elif math_inst.instruction_shape[1] == 64: + tile_descriptions = [ + # 256x64x128 + TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,2,1]), + TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [2,1,1]), + TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + + else: + assert False, "math inst is not supported" + + if CudaToolkitVersionSatisfies(cuda_version, 12, 1): + schedules = [ + [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto], + [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized], + [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized], + [KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized], + [KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized], + [KernelScheduleType.TmaWarpSpecializedFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized] + ] + stream_k_schedules = [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized], + [KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]] + else: + schedules = [ + [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto], + [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized] + # TmaWarpSpecializedCooperative require CUDA version >= 12.1 for optimal performance. + ] + stream_k_schedules = [] + + + for data_type in data_types: + # With No-SMEM epilogues + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, schedules) + + if CudaToolkitVersionSatisfies(cuda_version, 12, 1): + # Persistent kernels with TMA epilogues + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative], + [KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.TmaWarpSpecialized], + [KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) + + # Small tiles + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions_small, data_type, + [[KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.TmaWarpSpecialized], + [KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]]) + + # Add stream-K variants (with and without TMA epilogues) + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK]) + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative], + [KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]], + tile_schedulers=[TileSchedulerType.StreamK]) + # def GenerateSM90_TensorOp_1684(manifest, cuda_version): @@ -4400,7 +4696,7 @@ def GenerateSM90_TensorOp_1684(manifest, cuda_version): MathOperation.multiply_add) min_cc = 90 - max_cc = 1024 + max_cc = 90 alignment_constraints = [1,] @@ -4448,7 +4744,7 @@ def GenerateSM90_TensorOp_1684_complex(manifest, cuda_version): MathOperation.multiply_add_complex) min_cc = 90 - max_cc = 1024 + max_cc = 90 alignment_constraints = [1,] @@ -4505,7 +4801,7 @@ def GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version): MathOperation.multiply_add_complex_gaussian) min_cc = 90 - max_cc = 1024 + max_cc = 90 alignment_constraints = [1,] @@ -4554,7 +4850,7 @@ def GenerateSM90_TensorOp_1684_rank_k(manifest, cuda_version): MathOperation.multiply_add) min_cc = 90 - max_cc = 1024 + max_cc = 90 alignment_constraints = [1,] @@ -4599,7 +4895,7 @@ def GenerateSM90_TensorOp_1684_rank_k_complex(manifest, cuda_version): MathOperation.multiply_add_complex) min_cc = 90 - max_cc = 1024 + max_cc = 90 alignment_constraints = [1,] @@ -4649,7 +4945,7 @@ def GenerateSM90_TensorOp_1684_rank_k_complex_gaussian(manifest, cuda_version): MathOperation.multiply_add_complex_gaussian) min_cc = 90 - max_cc = 1024 + max_cc = 90 alignment_constraints = [1,] @@ -4706,7 +5002,7 @@ def GenerateSM90_TensorOp_1684_trmm(manifest, cuda_version): MathOperation.multiply_add) min_cc = 90 - max_cc = 1024 + max_cc = 90 alignment_constraints = [1,] @@ -4754,7 +5050,7 @@ def GenerateSM90_TensorOp_1684_trmm_complex(manifest, cuda_version): MathOperation.multiply_add_complex) min_cc = 90 - max_cc = 1024 + max_cc = 90 alignment_constraints = [1,] @@ -4808,7 +5104,7 @@ def GenerateSM90_TensorOp_1684_trmm_complex_gaussian(manifest, cuda_version): MathOperation.multiply_add_complex_gaussian) min_cc = 90 - max_cc = 1024 + max_cc = 90 alignment_constraints = [1,] @@ -4854,7 +5150,7 @@ def GenerateSM90_TensorOp_1684_symm(manifest, cuda_version): MathOperation.multiply_add) min_cc = 90 - max_cc = 1024 + max_cc = 90 alignment_constraints = [1,] @@ -4902,7 +5198,7 @@ def GenerateSM90_TensorOp_1684_symm_complex(manifest, cuda_version): MathOperation.multiply_add_complex) min_cc = 90 - max_cc = 1024 + max_cc = 90 alignment_constraints = [1,] @@ -4954,7 +5250,7 @@ def GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version): MathOperation.multiply_add_complex_gaussian) min_cc = 90 - max_cc = 1024 + max_cc = 90 alignment_constraints = [1,] @@ -4987,6 +5283,7 @@ def GenerateSM90(manifest, cuda_version): GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version) GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version) GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version) GenerateSM90_TensorOp_1684(manifest, cuda_version) GenerateSM90_TensorOp_1684_complex(manifest, cuda_version) GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version) diff --git a/tools/library/scripts/library.py b/tools/library/scripts/library.py index 2f0fcfb4cc..88ec518ec1 100644 --- a/tools/library/scripts/library.py +++ b/tools/library/scripts/library.py @@ -263,6 +263,7 @@ class MathOperation(enum.Enum): multiply_add = enum_auto() multiply_add_saturate = enum_auto() xor_popc = enum_auto() + and_popc = enum_auto() multiply_add_fast_bf16 = enum_auto() multiply_add_fast_f16 = enum_auto() multiply_add_fast_f32 = enum_auto() @@ -275,6 +276,7 @@ class MathOperation(enum.Enum): MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd', MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate', MathOperation.xor_popc: 'cutlass::arch::OpXorPopc', + MathOperation.and_popc: 'cutlass::arch::OpAndPopc', MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16', MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16', MathOperation.multiply_add_fast_f32: 'cutlass::arch::OpMultiplyAddFastF32', @@ -373,6 +375,9 @@ class KernelScheduleType(enum.Enum): TmaWarpSpecialized = enum_auto() TmaWarpSpecializedPingpong = enum_auto() TmaWarpSpecializedCooperative = enum_auto() + TmaWarpSpecializedFP8FastAccum = enum_auto() + TmaWarpSpecializedCooperativeFP8FastAccum = enum_auto() + TmaWarpSpecializedPingpongFP8FastAccum = enum_auto() # KernelScheduleTag = { KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto', @@ -381,6 +386,9 @@ class KernelScheduleType(enum.Enum): KernelScheduleType.TmaWarpSpecialized: 'cutlass::gemm::KernelTmaWarpSpecialized', KernelScheduleType.TmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpong', KernelScheduleType.TmaWarpSpecializedCooperative: 'cutlass::gemm::KernelTmaWarpSpecializedCooperative', + KernelScheduleType.TmaWarpSpecializedFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum', + KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum', + KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: 'cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum', } # @@ -391,6 +399,9 @@ class KernelScheduleType(enum.Enum): KernelScheduleType.TmaWarpSpecialized: '_warpspecialized', KernelScheduleType.TmaWarpSpecializedPingpong: '_warpspecialized_pingpong', KernelScheduleType.TmaWarpSpecializedCooperative: '_warpspecialized_cooperative', + KernelScheduleType.TmaWarpSpecializedFP8FastAccum: '_warpspecialized_fp8_fastaccum', + KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum: '_warpspecialized_cooperative_fp8_fastaccum', + KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum: '_warpspecialized_pingpong_fp8_fastaccum', } class EpilogueScheduleType(enum.Enum): @@ -417,6 +428,24 @@ class EpilogueScheduleType(enum.Enum): EpilogueScheduleType.TmaWarpSpecializedCooperative: '_epi_tma', } +class TileSchedulerType(enum.Enum): + Default = enum_auto() + Persistent = enum_auto() + StreamK = enum_auto() +# +TileSchedulerTag = { + TileSchedulerType.Default: 'void', + TileSchedulerType.Persistent: 'cutlass::gemm::PersistentScheduler', + TileSchedulerType.StreamK: 'cutlass::gemm::StreamKScheduler', +} + +# +TileSchedulerSuffixes = { + TileSchedulerType.Default: '', + TileSchedulerType.Persistent: '', + TileSchedulerType.StreamK: '_stream_k', +} + ################################################################################################### # diff --git a/tools/library/scripts/manifest.py b/tools/library/scripts/manifest.py index 48fdccb56c..d4f0483bda 100644 --- a/tools/library/scripts/manifest.py +++ b/tools/library/scripts/manifest.py @@ -337,7 +337,7 @@ def filter(self, operation): enabled = False - # todo: filter based on compute data type + # TODO: filter based on compute data type return enabled # diff --git a/tools/library/scripts/rank_2k_operation.py b/tools/library/scripts/rank_2k_operation.py index ebb7c3eb80..df43ca9de4 100644 --- a/tools/library/scripts/rank_2k_operation.py +++ b/tools/library/scripts/rank_2k_operation.py @@ -80,6 +80,7 @@ def core_name(self): math_operations_map = { MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and' } if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ diff --git a/tools/library/scripts/rank_k_operation.py b/tools/library/scripts/rank_k_operation.py index 5ff596be4b..74ce78ac26 100644 --- a/tools/library/scripts/rank_k_operation.py +++ b/tools/library/scripts/rank_k_operation.py @@ -78,6 +78,7 @@ def core_name(self): math_operations_map = { MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and' } if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ diff --git a/tools/library/scripts/symm_operation.py b/tools/library/scripts/symm_operation.py index dbee11b4c7..af01fdc90c 100644 --- a/tools/library/scripts/symm_operation.py +++ b/tools/library/scripts/symm_operation.py @@ -80,6 +80,7 @@ def core_name(self): math_operations_map = { MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and' } if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ diff --git a/tools/library/scripts/trmm_operation.py b/tools/library/scripts/trmm_operation.py index 7e03e278ef..c234e6f9ce 100644 --- a/tools/library/scripts/trmm_operation.py +++ b/tools/library/scripts/trmm_operation.py @@ -78,6 +78,7 @@ def core_name(self): math_operations_map = { MathOperation.xor_popc: 'xor', + MathOperation.and_popc: 'and' } if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ diff --git a/tools/library/src/gemm_operation_3x.hpp b/tools/library/src/gemm_operation_3x.hpp index eec57169cf..4f9e39be8e 100644 --- a/tools/library/src/gemm_operation_3x.hpp +++ b/tools/library/src/gemm_operation_3x.hpp @@ -33,11 +33,11 @@ */ #pragma once + #include "cutlass/cutlass.h" #include "cutlass/library/library.h" #include "library_internal.h" - /////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::library { @@ -60,7 +60,7 @@ class GemmOperation3xBase : public Operation { // assuming all tensors use same type for StrideIndex using StrideIndex = typename Operator::LayoutA::Index; using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::CollectiveEpilogue::ElementCompute; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; private: @@ -130,6 +130,11 @@ class GemmOperation3xBase : public Operation { virtual OperationDescription const & description() const { return description_; } + + /// Returns the description of the GEMM operation + GemmDescription const& get_gemm_description() const { + return description_; + } }; /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -159,8 +164,7 @@ class GemmUniversal3xOperation : public GemmOperation3xBase { /// Constructor GemmUniversal3xOperation(char const *name = "unknown_gemm"): - GemmOperation3xBase(name, GemmKind::kUniversal) { - } + GemmOperation3xBase(name, GemmKind::kUniversal) {} protected: @@ -175,23 +179,49 @@ class GemmUniversal3xOperation : public GemmOperation3xBase { return Status::kSuccess; } + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmUniversalArguments const &arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, GemmUniversalArguments const &arguments) { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + /// Constructs the arguments structure given the configuration and arguments static Status update_arguments_( OperatorArguments &operator_args, GemmUniversalArguments const *arguments) { - if (arguments->pointer_mode == ScalarPointerMode::kHost) { - typename ThreadEpilogueOp::Params params( - *static_cast(arguments->alpha), - *static_cast(arguments->beta)); - operator_args.epilogue.thread = params; - } - else if (arguments->pointer_mode == ScalarPointerMode::kDevice) { - typename ThreadEpilogueOp::Params params( - static_cast(arguments->alpha), - static_cast(arguments->beta)); - operator_args.epilogue.thread = params; - } - else { - return Status::kErrorInvalidProblem; + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; } // TODO: type erase Arguments structure in 3.0 GEMM @@ -218,7 +248,7 @@ class GemmUniversal3xOperation : public GemmOperation3xBase { /* Query device SM count to pass onto the kernel as an argument, where needed */ operator_args.hw_info.sm_count = arguments->sm_count; - return Status::kSuccess; + return status; } public: @@ -297,7 +327,6 @@ class GemmUniversal3xOperation : public GemmOperation3xBase { return status; } }; - /////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::library diff --git a/tools/library/src/reference/gemm.cu b/tools/library/src/reference/gemm.cu index 77dcfad243..e314155c1f 100644 --- a/tools/library/src/reference/gemm.cu +++ b/tools/library/src/reference/gemm.cu @@ -179,6 +179,16 @@ void initialize_gemm_reference_operations(Manifest &manifest) { NumericConverterClamp >(manifest); + make_gemm_real_canonical_layouts< + int8_t, + int8_t, + int8_t, + int32_t, + int32_t, + int8_t, + NumericConverterClamp + >(manifest); + make_gemm_interleaved_layouts< 32, int8_t, @@ -344,7 +354,6 @@ void initialize_gemm_reference_operations(Manifest &manifest) { complex, complex >(manifest); - } /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/reference/gemm_fp8.cu b/tools/library/src/reference/gemm_fp8.cu new file mode 100644 index 0000000000..a5c119fffc --- /dev/null +++ b/tools/library/src/reference/gemm_fp8.cu @@ -0,0 +1,418 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations for FP8. +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_gemm_fp8_reference_operations(Manifest &manifest) { + // + // FP8 GEMMs + // + ////////////////////////////////// + /// ElementC: half_t + ////////////////////////////////// + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + half_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float , // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float , // ElementAccumulator + half_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + half_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + half_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + half_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + /// ElementC: bfloat16_t + ////////////////////////////////// + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + bfloat16_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + bfloat16_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + bfloat16_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + bfloat16_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + bfloat16_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + ////////////////////////////////// + /// ElementC: float + ////////////////////////////////// + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e4m3_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + ////////////////////////////////// + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e5m2_t, // ElementA + float_e5m2_t, // ElementB + float, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e5m2_t // ElementD + >(manifest); + + make_gemm_real_canonical_layouts< + float_e4m3_t, // ElementA + float_e4m3_t, // ElementB + float_e4m3_t, // ElementC + float, // ElementScalar + float, // ElementAccumulator + float_e4m3_t // ElementD + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/initialize_reference_operations.cu b/tools/library/src/reference/initialize_reference_operations.cu index b63367e40b..1b3efebc37 100644 --- a/tools/library/src/reference/initialize_reference_operations.cu +++ b/tools/library/src/reference/initialize_reference_operations.cu @@ -43,6 +43,7 @@ namespace cutlass { namespace library { void initialize_gemm_reference_operations(Manifest &manifest); +void initialize_gemm_fp8_reference_operations(Manifest &manifest); void initialize_conv2d_reference_operations(Manifest &manifest); void initialize_conv3d_reference_operations(Manifest &manifest); @@ -52,6 +53,7 @@ void initialize_reference_operations(Manifest &manifest) { initialize_conv2d_reference_operations(manifest); initialize_conv3d_reference_operations(manifest); initialize_gemm_reference_operations(manifest); + initialize_gemm_fp8_reference_operations(manifest); } /////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/profiler/CMakeLists.txt b/tools/profiler/CMakeLists.txt index 368f8b9a61..16cb9051ed 100644 --- a/tools/profiler/CMakeLists.txt +++ b/tools/profiler/CMakeLists.txt @@ -26,6 +26,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED) + # # Sources for CUTLASS Profiler Tool # diff --git a/tools/profiler/src/cutlass_profiler.cu b/tools/profiler/src/cutlass_profiler.cu index a4f3778085..e4db2290ed 100644 --- a/tools/profiler/src/cutlass_profiler.cu +++ b/tools/profiler/src/cutlass_profiler.cu @@ -130,7 +130,6 @@ int CutlassProfiler::operator()() { // Enumerates all operations enumerate_(); } - return 0; } diff --git a/tools/profiler/src/device_allocation.cu b/tools/profiler/src/device_allocation.cu index 600950e2d8..2c6fdc3ffc 100644 --- a/tools/profiler/src/device_allocation.cu +++ b/tools/profiler/src/device_allocation.cu @@ -938,6 +938,465 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) { copy_from_host(host_data.data()); } +void DeviceAllocation::initialize_sequential_device(Distribution dist) { + if (!bytes()) { +#ifndef NDEBUG + std::cout << "Skipping initialization of size 0 allocation\n"; +#endif + return; + } + + if (!data()) { + throw std::runtime_error("Attempting to initialize invalid allocation."); + } + + switch (type_) { + case library::NumericTypeID::kFE4M3: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kFE5M2: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kF16: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kBF16: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kTF32: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kF32: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kCF16: + cutlass::reference::device::BlockFillSequential>( + reinterpret_cast *>(pointer_), + capacity_, + cutlass::complex( + static_cast(dist.sequential.delta)), + cutlass::complex( + static_cast(dist.sequential.start)) + ); + break; + case library::NumericTypeID::kCBF16: + cutlass::reference::device::BlockFillSequential>( + reinterpret_cast *>(pointer_), + capacity_, + cutlass::complex( + static_cast(dist.sequential.delta)), + cutlass::complex( + static_cast(dist.sequential.start)) + ); + break; + case library::NumericTypeID::kCTF32: + cutlass::reference::device::BlockFillSequential>( + reinterpret_cast *>(pointer_), + capacity_, + cutlass::complex( + static_cast(dist.sequential.delta)), + cutlass::complex( + static_cast(dist.sequential.start)) + ); + break; + case library::NumericTypeID::kCF32: + cutlass::reference::device::BlockFillSequential>( + reinterpret_cast *>(pointer_), + capacity_, + cutlass::complex( + static_cast(dist.sequential.delta)), + cutlass::complex( + static_cast(dist.sequential.start)) + ); + break; + case library::NumericTypeID::kF64: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kCF64: + cutlass::reference::device::BlockFillSequential>( + reinterpret_cast *>(pointer_), + capacity_, + cutlass::complex( + static_cast(dist.sequential.delta)), + cutlass::complex( + static_cast(dist.sequential.start)) + ); + break; + case library::NumericTypeID::kS2: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kS4: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kS8: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kS16: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kS32: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kS64: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kB1: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kU2: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kU4: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kU8: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kU16: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kU32: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kU64: + cutlass::reference::device::BlockFillSequential( + reinterpret_cast(pointer_), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + default: break; + } + +} + +void DeviceAllocation::initialize_sequential_host(Distribution dist) { + if (!bytes()) { +#ifndef NDEBUG + std::cout << "Skipping initialization of size 0 allocation\n"; +#endif + return; + } + + if (!data()) { + throw std::runtime_error("Attempting to initialize invalid allocation."); + } + + std::vector host_data(bytes()); + + switch (type_) { + case library::NumericTypeID::kFE4M3: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kFE5M2: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kF16: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kBF16: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kTF32: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kF32: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kCF16: + cutlass::reference::host::BlockFillSequential>( + reinterpret_cast *>(host_data.data()), + capacity_, + cutlass::complex( + static_cast(dist.sequential.delta)), + cutlass::complex( + static_cast(dist.sequential.start)) + ); + break; + case library::NumericTypeID::kCBF16: + cutlass::reference::host::BlockFillSequential>( + reinterpret_cast *>(host_data.data()), + capacity_, + cutlass::complex( + static_cast(dist.sequential.delta)), + cutlass::complex( + static_cast(dist.sequential.start)) + ); + break; + case library::NumericTypeID::kCTF32: + cutlass::reference::host::BlockFillSequential>( + reinterpret_cast *>(host_data.data()), + capacity_, + cutlass::complex( + static_cast(dist.sequential.delta)), + cutlass::complex( + static_cast(dist.sequential.start)) + ); + break; + case library::NumericTypeID::kCF32: + cutlass::reference::host::BlockFillSequential>( + reinterpret_cast *>(host_data.data()), + capacity_, + cutlass::complex( + static_cast(dist.sequential.delta)), + cutlass::complex( + static_cast(dist.sequential.start)) + ); + break; + case library::NumericTypeID::kF64: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kCF64: + cutlass::reference::host::BlockFillSequential>( + reinterpret_cast *>(host_data.data()), + capacity_, + cutlass::complex( + static_cast(dist.sequential.delta)), + cutlass::complex( + static_cast(dist.sequential.start)) + ); + break; + case library::NumericTypeID::kS2: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kS4: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kS8: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kS16: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kS32: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kS64: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kB1: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kU2: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kU4: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kU8: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kU16: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kU32: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + case library::NumericTypeID::kU64: + cutlass::reference::host::BlockFillSequential( + reinterpret_cast(host_data.data()), + capacity_, + static_cast(dist.sequential.delta), + static_cast(dist.sequential.start) + ); + break; + default: break; + } + + copy_from_host(host_data.data()); +} + void DeviceAllocation::initialize_random_sparsemeta_device(int seed, int MetaSizeInBits) { if (!bytes()) { #ifndef NDEBUG diff --git a/tools/profiler/src/device_allocation.h b/tools/profiler/src/device_allocation.h index f1362e7654..b5b3ee4af5 100644 --- a/tools/profiler/src/device_allocation.h +++ b/tools/profiler/src/device_allocation.h @@ -194,6 +194,12 @@ class DeviceAllocation { /// Initializes a host allocation to a random distribution using std::cout void initialize_random_host(int seed, Distribution dist); + /// Initializes a device allocation to a sequential distribution + void initialize_sequential_device(Distribution dist); + + /// Initializes a host allocation to a sequential distribution + void initialize_sequential_host(Distribution dist); + /// Initializes a device allocation to a random distribution using cuRAND void initialize_random_sparsemeta_device(int seed, int MetaSizeInBits); diff --git a/tools/profiler/src/device_context.cu b/tools/profiler/src/device_context.cu index 43a5ebd3f3..cad454f8f1 100644 --- a/tools/profiler/src/device_context.cu +++ b/tools/profiler/src/device_context.cu @@ -123,15 +123,44 @@ DeviceAllocation *DeviceContext::allocate_tensor( } } + // Override pnz for the A/B/C tensors if overridden for Gaussian distributions + if (data_distribution.kind == Distribution::Gaussian) { + double mean = data_distribution.gaussian.mean; + double stddev = data_distribution.gaussian.stddev; + int scale = data_distribution.int_scale; + + if (name == "A" && data_distribution.gaussian.pnzA != 100.0) { + data_distribution.set_gaussian(mean, stddev, scale, data_distribution.gaussian.pnzA); + } + else if (name == "B" && data_distribution.gaussian.pnzB != 100.0) { + data_distribution.set_gaussian(mean, stddev, scale, data_distribution.gaussian.pnzB); + } + else if (name == "C" && data_distribution.gaussian.pnzC != 100.0) { + data_distribution.set_gaussian(mean, stddev, scale, data_distribution.gaussian.pnzC); + } + } + if (options.initialization.provider == library::Provider::kReferenceDevice) { - allocation->initialize_random_device( - options.initialization.seed + seed_shift, - data_distribution); + if (data_distribution.kind == Distribution::Sequential) { + allocation->initialize_sequential_device( + data_distribution); + } + else { + allocation->initialize_random_device( + options.initialization.seed + seed_shift, + data_distribution); + } } else if (options.initialization.provider == library::Provider::kReferenceHost) { - allocation->initialize_random_host( - options.initialization.seed + seed_shift, - data_distribution); + if (data_distribution.kind == Distribution::Sequential) { + allocation->initialize_sequential_host( + data_distribution); + } + else { + allocation->initialize_random_host( + options.initialization.seed + seed_shift, + data_distribution); + } } } diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index a622e048a1..8c8f8b2124 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -322,7 +322,6 @@ void GemmOperationProfiler::GemmProblem::initialize_result( set_argument(result, "split_k_mode", problem_space, library::to_string(split_k_mode)); set_argument(result, "split_k_slices", problem_space, split_k_slices); set_argument(result, "batch_count", problem_space, batch_count); - set_argument(result, "alpha", problem_space, library::lexical_cast(alpha, operation_desc.element_epilogue)); @@ -377,7 +376,6 @@ Status GemmOperationProfiler::initialize_configuration( gemm_workspace_.arguments.alpha = problem_.alpha.data(); gemm_workspace_.arguments.beta = problem_.beta.data(); gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; - // initialize reduction operation for parallel splitKMode if (problem_.split_k_mode == library::SplitKMode::kParallel) { if (!initialize_reduction_configuration_(operation, problem)) { @@ -492,7 +490,8 @@ Status GemmOperationProfiler::initialize_workspace( gemm_workspace_.problem_count = options.profiling.workspace_count; } - if (options.execution_mode != ExecutionMode::kDryRun) { + bool allocate_device_tensors = options.execution_mode != ExecutionMode::kDryRun; + if (allocate_device_tensors) { int seed_shift = 0; gemm_workspace_.A = device_context.allocate_tensor( options, @@ -544,7 +543,9 @@ Status GemmOperationProfiler::initialize_workspace( {int(problem_.ldc)}, problem_.batch_count * gemm_workspace_.problem_count ); + } + if (options.execution_mode != ExecutionMode::kDryRun) { // NOTE: the leading non-batch strides are duplicated here for 3.0 API kernels gemm_workspace_.arguments.problem_size = {int(problem_.m), int(problem_.n), int(problem_.k)}; gemm_workspace_.arguments.batch_count = problem_.batch_count; @@ -552,10 +553,10 @@ Status GemmOperationProfiler::initialize_workspace( gemm_workspace_.arguments.ldb = problem_.ldb; gemm_workspace_.arguments.ldc = problem_.ldc; gemm_workspace_.arguments.ldd = problem_.ldc; - gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); - gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); - gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); - gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); + gemm_workspace_.arguments.batch_stride_A = problem_.lda; + gemm_workspace_.arguments.batch_stride_B = problem_.ldb; + gemm_workspace_.arguments.batch_stride_C = problem_.ldc; + gemm_workspace_.arguments.batch_stride_D = problem_.ldc; /* Query device SM count to pass onto the kernel as an argument, where needed */ gemm_workspace_.arguments.sm_count = options.device.properties.multiProcessorCount; @@ -739,26 +740,39 @@ bool GemmOperationProfiler::verify_cutlass( } #endif // #if CUTLASS_ENABLE_CUBLAS - verify_with_reference_(options, report, device_context, operation, problem_space, problem); + bool verification_status = verify_with_reference_(options, report, device_context, operation, problem_space, problem); // Update disposition to worst case verification outcome among all // verification providers which are supported bool is_any_verification_run_passed = false; - for(auto &m : results_.back().verification_map) { - if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { + for (auto &m : results_.back().verification_map) { + if (m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { results_.back().disposition = m.second; return true; } - if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { + if (!is_any_verification_run_passed && m.second == Disposition::kPassed) { is_any_verification_run_passed = true; } } - if(is_any_verification_run_passed) { + if (is_any_verification_run_passed) { results_.back().disposition = Disposition::kPassed; } } + // if verification.required is set, then return success iff at least one ref-check was run + if (options.verification.required) { + bool did_any_verification_run = false; + for (auto provider : options.verification.providers) { + did_any_verification_run |= (Disposition::kNotRun != results_.back().verification_map[provider]); + } + + if (not did_any_verification_run) { + results_.back().status = Status::kErrorNotSupported; + return false; + } + } + // Return true means continue profiling return true; } @@ -902,12 +916,7 @@ bool GemmOperationProfiler::verify_with_reference_( // Initialize state // - library::Provider references[] = { - library::Provider::kReferenceDevice, - library::Provider::kReferenceHost - }; - - for (auto provider : references) { + for (auto provider : options.verification.providers) { // Skip providers that are not enabled if (!options.verification.provider_enabled(provider)) { @@ -994,7 +1003,7 @@ bool GemmOperationProfiler::verify_with_reference_( if (status != Status::kSuccess) { results_.back().verification_map[provider] = Disposition::kNotRun; - return true; + continue; } results_.back().status = status; diff --git a/tools/profiler/src/operation_profiler.cu b/tools/profiler/src/operation_profiler.cu index 737821c19c..a3ed990ba0 100644 --- a/tools/profiler/src/operation_profiler.cu +++ b/tools/profiler/src/operation_profiler.cu @@ -386,6 +386,8 @@ int OperationProfiler::profile_all( operation, problem_space, problem); + + retval |= (not continue_profiling); } if (options.execution_mode == ExecutionMode::kDryRun) { diff --git a/tools/profiler/src/operation_profiler.h b/tools/profiler/src/operation_profiler.h index 17b4413c41..92bb41e359 100644 --- a/tools/profiler/src/operation_profiler.h +++ b/tools/profiler/src/operation_profiler.h @@ -39,6 +39,9 @@ #include #include +// CUTLASS includes +#include "cutlass/trace.h" + // CUTLASS Library includes #include "cutlass/library/library.h" #include "cutlass/library/util.h" diff --git a/tools/profiler/src/options.cu b/tools/profiler/src/options.cu index 05f6530246..4bc03baee5 100644 --- a/tools/profiler/src/options.cu +++ b/tools/profiler/src/options.cu @@ -239,11 +239,20 @@ void Options::Initialization::get_distribution( {"max", &dist.uniform.max}, {"mean", &dist.gaussian.mean}, {"stddev", &dist.gaussian.stddev}, + {"pnzA", &dist.gaussian.pnzA}, + {"pnzB", &dist.gaussian.pnzB}, + {"pnzC", &dist.gaussian.pnzC}, {"start", &dist.sequential.start}, {"delta", &dist.sequential.delta}, {0, 0} }; + // Initalize pnz values to a default value of 100% + dist.gaussian.pnz = 100.0; + dist.gaussian.pnzA = 100.0; + dist.gaussian.pnzB = 100.0; + dist.gaussian.pnzC = 100.0; + using KeyValueVector = std::vector >; KeyValueVector values; @@ -302,7 +311,7 @@ void Options::Initialization::print_usage(std::ostream &out) const { << " --dist= " << " Data distribution of input tensors {uniform*, gaussian, identity, sequential}" << end_of_line << " --dist=uniform,min:,max:,scale:" << end_of_line - << " --dist=gaussian,mean:,stddev:,scale:" << end_of_line + << " --dist=gaussian,mean:,stddev:,scale:,pnzA:,pnzB:,pnzC:" << end_of_line << " --dist=sequential,start:,delta:,scale:" << end_of_line << " --dist=identity\n\n" @@ -340,7 +349,7 @@ Options::Library::Library(cutlass::CommandLine const &cmdline) { for (auto const & token : tokens) { if (token.find(":")) { - // todo - tokenized range + // TODO: tokenized range } else { int algo; @@ -473,6 +482,9 @@ size_t Options::Profiling::index(library::Provider provider) const { Options::Verification::Verification(cutlass::CommandLine const &cmdline) { cmdline.get_cmd_line_argument("verification-enabled", enabled, true); + if (enabled) { + cmdline.get_cmd_line_argument("verification-required", required, false); + } cmdline.get_cmd_line_argument("epsilon", epsilon, 0.05); diff --git a/tools/profiler/src/options.h b/tools/profiler/src/options.h index d679c70ee1..03dd71ee89 100644 --- a/tools/profiler/src/options.h +++ b/tools/profiler/src/options.h @@ -150,6 +150,10 @@ class Options { /// If true, kernels are verified before they are profiled bool enabled; + /// If true, causes profiler to return an error code if no reference check is run. + /// Only valid when verification is enabled. + bool required; + /// Relative error threshold - zero to require bit-level consistency double epsilon; diff --git a/tools/util/include/cutlass/util/distribution.h b/tools/util/include/cutlass/util/distribution.h index 7fee888452..d5557d952a 100644 --- a/tools/util/include/cutlass/util/distribution.h +++ b/tools/util/include/cutlass/util/distribution.h @@ -57,6 +57,10 @@ struct Distribution { struct { double mean; double stddev; + double pnz; + double pnzA; + double pnzB; + double pnzC; } gaussian; /// Elements are linear combination of row and column index @@ -88,10 +92,11 @@ struct Distribution { } /// Configures distribution as Gaussian distribution - Distribution &set_gaussian(double _mean, double _stddev, int _int_scale = 0) { + Distribution &set_gaussian(double _mean, double _stddev, int _int_scale = 0, double _pnz = 100.0) { kind = Gaussian; gaussian.mean = _mean; gaussian.stddev = _stddev; + gaussian.pnz = _pnz; int_scale = _int_scale; return *this; } @@ -123,7 +128,9 @@ inline std::ostream &operator<<(std::ostream &out, cutlass::Distribution const & out << "uniform, min: " << dist.uniform.min << ", max: " << dist.uniform.max; break; case cutlass::Distribution::Gaussian: - out << "gaussian, mean: " << dist.gaussian.mean << ", stddev: " << dist.gaussian.stddev; + out << "gaussian, mean: " << dist.gaussian.mean << ", stddev: " << dist.gaussian.stddev + << ", pnzA: " << dist.gaussian.pnzA << ", pnzB: " + << dist.gaussian.pnzB << ", pnzC: " << dist.gaussian.pnzC; break; case cutlass::Distribution::Identity: out << "identity"; diff --git a/tools/util/include/cutlass/util/host_tensor.h b/tools/util/include/cutlass/util/host_tensor.h index 9909ee9df8..4b2b8d152b 100644 --- a/tools/util/include/cutlass/util/host_tensor.h +++ b/tools/util/include/cutlass/util/host_tensor.h @@ -47,6 +47,7 @@ #include "cutlass/cutlass.h" #include "cutlass/tensor_ref.h" #include "cutlass/tensor_view.h" +#include "cutlass/fast_math.h" #include "device_memory.h" @@ -103,8 +104,17 @@ class HostTensor { /// Constant reference to element in tensor using ConstReference = typename ConstTensorRef::Reference; - /// Used to handle packing of subbyte elements - static int const kElementsPerStoredItem = (sizeof_bits::value < 8 ? (8 / sizeof_bits::value) : 1); + /// Note: Below is used to handle packing of subbyte elements + /// kBitsStoredVec : The bits of store vec that could be divisiable by the element + /// kElementsPerStoredVec : The number of elements could be stored in per store vec + /// kNumStoragePerStoredVec : How much storage(i.e. sizeof(element storage)) the store vec needs to consume. + /// Usually the element storage of subbyte is uint8_t. + /// Example + /// int2: kBitsStoredVec = 8; kElementsPerStoredVec = 4; kNumStoragePerStoredVec = 1 uint8_t; + /// int4: kBitsStoredVec = 8; kElementsPerStoredVec = 2; kNumStoragePerStoredVec = 1 uint8_t; + static int const kBitsStoredVec = (sizeof_bits::value < 8) ? cutlass::lcm(sizeof_bits::value, 8) : sizeof_bits::value; + static int const kElementsPerStoredVec = kBitsStoredVec / sizeof_bits::value; + static int const kNumStoragePerStoredVec = kBitsStoredVec / (sizeof(Element) * 8); private: @@ -170,8 +180,7 @@ class HostTensor { device_.reset(); host_.clear(); - count /= kElementsPerStoredItem; - + count = count / kElementsPerStoredVec * kNumStoragePerStoredVec; host_.resize(count); // Allocate memory @@ -217,7 +226,7 @@ class HostTensor { LongIndex new_size = size_t(layout_.capacity(extent_)); if (static_cast(new_size) > host_.size()) { - reserve(new_size); + reserve(new_size, device_backed_); } } @@ -232,7 +241,7 @@ class HostTensor { /// Returns the number of elements stored in the host tensor size_t size() const { - return host_.size() * kElementsPerStoredItem; + return host_.size() / kNumStoragePerStoredVec * kElementsPerStoredVec; } /// Returns the logical capacity based on extent and layout. May differ from size(). @@ -254,6 +263,9 @@ class HostTensor { /// Gets pointer to host data Element const * host_data() const { return host_.data(); } + /// Gets pointer to host data with a pointer offset + Element const * host_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory::get(host_.data(), ptr_element_offset); } + /// Gets a constant reference to an element in host memory ConstReference host_data(LongIndex idx) const { return ReferenceFactory::get(host_data(), idx); @@ -262,11 +274,14 @@ class HostTensor { /// Gets pointer to device data Element * device_data() { return device_.get(); } + /// Gets pointer to device data + Element const * device_data() const { return device_.get(); } + /// Gets pointer to device data with a pointer offset Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(device_data(), ptr_element_offset); } - /// Gets pointer to device data - Element const * device_data() const { return device_.get(); } + /// Gets pointer to device data with a pointer offset + Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory::get(device_data(), ptr_element_offset); } /// Accesses the tensor reference pointing to data TensorRef host_ref(LongIndex ptr_element_offset=0) { return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } diff --git a/tools/util/include/cutlass/util/packed_stride.hpp b/tools/util/include/cutlass/util/packed_stride.hpp index 7ecffaffa1..b21582e0a8 100644 --- a/tools/util/include/cutlass/util/packed_stride.hpp +++ b/tools/util/include/cutlass/util/packed_stride.hpp @@ -29,7 +29,7 @@ * **************************************************************************************************/ /*! \file - \brief Utilities for packing a rank-X shape into a rank-(X-1) stride in CuTe. + \brief Utilities for packing constructing canonical CuTe stride types for 3.x mainloop params. */ #pragma once @@ -38,25 +38,29 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + // Strides without batch mode -template -cute::Stride> -make_cute_packed_stride(cute::Stride> s, cute::Shape shape_MKL) { - static_assert(std::is_integral_v, +template +cute::Stride> +make_cute_packed_stride(cute::Stride> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, "Stride must have an integral type so it can be set dynamically. Static strides not supported."); auto s_copy = s; - cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); return s_copy; } -template -cute::Stride, StrideIntT> -make_cute_packed_stride(cute::Stride, StrideIntT> s, cute::Shape shape_MKL) { - static_assert(std::is_integral_v, +template +cute::Stride, IntT> +make_cute_packed_stride(cute::Stride, IntT> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, "Stride must have an integral type so it can be set dynamically. Static strides not supported."); auto s_copy = s; - cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); return s_copy; } @@ -64,38 +68,40 @@ make_cute_packed_stride(cute::Stride, StrideIntT> s, cute::Shape -cute::Stride, int64_t> -make_cute_packed_stride(cute::Stride, int64_t> s, cute::Shape shape_MKL) { - static_assert(std::is_integral_v, +template +cute::Stride, int64_t> +make_cute_packed_stride(cute::Stride, int64_t> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, "Stride must have an integral type so it can be set dynamically. Static strides not supported."); auto s_copy = s; - cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); int batch_count = cute::get<2>(shape_MKL); if (batch_count > 1) { - cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); + cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); } else { - cute::get<2>(s_copy) = static_cast(0); + cute::get<2>(s_copy) = static_cast(0); } return s_copy; } -template -cute::Stride, StrideIntT, int64_t> -make_cute_packed_stride(cute::Stride, StrideIntT, int64_t> s, cute::Shape shape_MKL) { - static_assert(std::is_integral_v, +template +cute::Stride, IntT, int64_t> +make_cute_packed_stride(cute::Stride, IntT, int64_t> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, "Stride must have an integral type so it can be set dynamically. Static strides not supported."); auto s_copy = s; - cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); int batch_count = cute::get<2>(shape_MKL); if (batch_count > 1) { - cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); + cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); } else { - cute::get<2>(s_copy) = static_cast(0); + cute::get<2>(s_copy) = static_cast(0); } return s_copy; } ///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/tools/util/include/cutlass/util/reference/device/tensor_fill.h b/tools/util/include/cutlass/util/reference/device/tensor_fill.h index 1b5e62bca3..b066de321f 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_fill.h @@ -354,19 +354,22 @@ template < void TensorFillRandomGaussian( TensorView view, ///< destination tensor uint64_t seed, ///< seed for RNG - Element mean = Element(0), ///< Gaussian distribution's mean - Element stddev = Element(1), ///< Gaussian distribution's standard deviation - int bits = -1) { ///< If non-negative, specifies number of fractional bits that + typename RealType::Type mean = Element(0), ///< Gaussian distribution's mean + typename RealType::Type stddev = Element(1), ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that /// are not truncated to zero. Permits reducing precision of /// data. - + cudaStream_t stream = nullptr) { + using RandomFunc = detail::RandomGaussianFunc; using Func = detail::TensorFillRandomGaussianFunc; using Params = typename Func::Params; TensorForEach( view.extent(), - Params(view, typename RandomFunc::Params(seed, mean, stddev, bits)) + Params(view, typename RandomFunc::Params(seed, mean, stddev, bits)), + /*grid_size*/0, /*block_size*/0, + stream ); } @@ -380,15 +383,16 @@ void BlockFillRandomGaussian( uint64_t seed, ///< seed for RNG typename RealType::Type mean, ///< Gaussian distribution's mean typename RealType::Type stddev, ///< Gaussian distribution's standard deviation - int bits = -1) { ///< If non-negative, specifies number of fractional bits that + int bits = -1, ///< If non-negative, specifies number of fractional bits that /// are not truncated to zero. Permits reducing precision of /// data. - + cudaStream_t stream = nullptr) { + using RandomFunc = detail::RandomGaussianFunc; typename RandomFunc::Params params(seed, mean, stddev, bits); - BlockForEach(ptr, capacity, params); + BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); } /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -686,12 +690,13 @@ template < void TensorFillRandomUniform( TensorView view, ///< destination tensor uint64_t seed, ///< seed for RNG - Element max = Element(1), ///< upper bound of distribution - Element min = Element(0), ///< lower bound for distribution - int bits = -1) { ///< If non-negative, specifies number of fractional bits that + typename RealType::Type max = Element(1), ///< upper bound of distribution + typename RealType::Type min = Element(0), ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that /// are not truncated to zero. Permits reducing precision of - /// data. - + /// data. + cudaStream_t stream = nullptr) { + using RandomFunc = detail::RandomUniformFunc; using Func = detail::TensorFillRandomUniformFunc; using Params = typename Func::Params; @@ -700,7 +705,9 @@ void TensorFillRandomUniform( TensorForEach( view.extent(), - Params(view, random) + Params(view, random), + /*grid_size*/0, /*block_size*/0, + stream ); } @@ -714,15 +721,16 @@ void BlockFillRandomUniform( uint64_t seed, ///< seed for RNG typename RealType::Type max, ///< upper bound of distribution typename RealType::Type min, ///< lower bound for distribution - int bits = -1) { ///< If non-negative, specifies number of fractional bits that + int bits = -1, ///< If non-negative, specifies number of fractional bits that /// are not truncated to zero. Permits reducing precision of - /// data. - + /// data. + cudaStream_t stream = nullptr) { + using RandomFunc = detail::RandomUniformFunc; - + typename RandomFunc::Params params(seed, max, min, bits); - BlockForEach(ptr, capacity, params); + BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); } /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -899,10 +907,9 @@ template < void TensorFillRandomSparseMeta( TensorView view, ///< destination tensor uint64_t seed, ///< seed for RNG - int MetaSizeInBits = 2) { ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of - /// data. - + int MetaSizeInBits = 2, ///< meta data size + cudaStream_t stream = nullptr) { + using RandomFunc = detail::RandomSparseMetaFunc; using Func = detail::TensorFillRandomUniformFunc; using Params = typename Func::Params; @@ -911,7 +918,9 @@ void TensorFillRandomSparseMeta( TensorForEach( view.extent(), - Params(view, random) + Params(view, random), + /*grid_size*/0, /*block_size*/0, + stream ); } @@ -923,13 +932,14 @@ void BlockFillRandomSparseMeta( Element *ptr, size_t capacity, uint64_t seed, ///< seed for RNG - int MetaSizeInBits = 2) { ///< meta data size - + int MetaSizeInBits = 2, ///< meta data size + cudaStream_t stream = nullptr) { + using RandomFunc = detail::RandomSparseMetaFunc; - + typename RandomFunc::Params params(seed, MetaSizeInBits); - BlockForEach(ptr, capacity, params); + BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); } /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1245,14 +1255,17 @@ template < void TensorFillDiagonal( TensorView view, ///< destination tensor Element diag = Element(1), ///< value to write in the diagonal - Element other = Element(0)) { ///< value to write off the diagonal - + Element other = Element(0), ///< value to write off the diagonal + cudaStream_t stream = nullptr) { + typedef detail::TensorFillDiagonalFunc Func; typedef typename Func::Params Params; TensorForEach( view.extent(), - Params(view, diag, other) + Params(view, diag, other), + /*grid_size*/0, /*block_size*/0, + stream ); } @@ -1264,18 +1277,20 @@ template < void TensorFillPartial( TensorView view, ///< destination tensor Element element, - FillMode fill_mode) { - + FillMode fill_mode, + cudaStream_t stream = nullptr) { + typedef detail::TensorFillPartialFunc Func; typedef typename Func::Params Params; TensorForEach( view.extent(), - Params(view, element, fill_mode) + Params(view, element, fill_mode), + stream ); } -/// Clears a tensor partially depending on fill mode and alignment. Elements on the wrong-side +/// Clears a tensor partially depending on fill mode and alignment. Elements on the wrong-side /// of fillmode (upto the alignment) are overwritten with the user supplied element (typically zeros) template < typename Element, ///< Element type @@ -1284,14 +1299,17 @@ void TensorClearPartial( TensorView view, ///< destination tensor Element element, FillMode fill_mode, - int alignment) { - + int alignment, + cudaStream_t stream = nullptr) { + typedef detail::TensorClearPartialFunc Func; typedef typename Func::Params Params; TensorForEach( view.extent(), - Params(view, element, fill_mode, alignment) + Params(view, element, fill_mode, alignment), + /*grid_size*/0, /*block_size*/0, + stream ); } @@ -1302,10 +1320,11 @@ template < typename Element, ///< Element type typename Layout> ///< Layout function void TensorFill( - TensorView view, ///< destination tensor - Element val = Element(0)) { ///< value to uniformly fill it with + TensorView view, ///< destination tensor + Element val = Element(0), ///< value to uniformly fill it with + cudaStream_t stream = nullptr) { - TensorFillDiagonal(view, val, val); + TensorFillDiagonal(view, val, val, stream); } /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1315,9 +1334,10 @@ template < typename Element, ///< Element type typename Layout> ///< Layout function void TensorFillIdentity( - TensorView view) { ///< destination tensor + TensorView view, ///< destination tensor + cudaStream_t stream = nullptr) { - TensorFillDiagonal(view, Element(1), Element(0)); + TensorFillDiagonal(view, Element(1), Element(0), stream); } /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1415,14 +1435,17 @@ template < typename Layout> ///< Layout function void TensorUpdateDiagonal( TensorView view, ///< destination tensor - Element diag = Element(1)) { + Element diag = Element(1), + cudaStream_t stream = nullptr) { typedef detail::TensorUpdateDiagonalFunc Func; typedef typename Func::Params Params; TensorForEach( view.extent(), - Params(view, diag) + Params(view, diag), + /*grid_size*/0, /*block_size*/0, + stream ); } @@ -1521,14 +1544,17 @@ template < typename Layout> ///< Layout function void TensorUpdateOffDiagonal( TensorView view, ///< destination tensor - Element other = Element(1)) { + Element other = Element(1), + cudaStream_t stream = nullptr) { typedef detail::TensorUpdateOffDiagonalFunc Func; typedef typename Func::Params Params; TensorForEach( view.extent(), - Params(view, other) + Params(view, other), + /*grid_size*/0, /*block_size*/0, + stream ); } @@ -1602,11 +1628,30 @@ struct TensorFillLinearFunc { /// Compute random value and update RNG state CUTLASS_DEVICE void operator()(TensorCoord const &coord) { + Element sum = params.s; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < Layout::kRank; ++i) { - sum += params.v[i] * Element(coord[i]); + if constexpr (is_complex::value) { + if constexpr (sizeof_bits::value <= 32) { + sum = Element(static_cast>(sum) + + static_cast>(params.v[i]) * static_cast>(coord[i])); + } + } + else if constexpr (sizeof_bits::value <= 32) { + if constexpr (std::numeric_limits::is_integer) { + sum = Element(static_cast(sum) + + static_cast(params.v[i]) * static_cast(coord[i])); + } + else { + sum = Element(static_cast(sum) + + static_cast(params.v[i]) * static_cast(coord[i])); + } + } + else { + sum += params.v[i] * coord[i]; + } } params.view.at(coord) = sum; @@ -1624,20 +1669,57 @@ template < void TensorFillLinear( TensorView view, ///< destination tensor Array const & v, - Element s = Element(0)) { + Element s = Element(0), + cudaStream_t stream = nullptr) { using Func = detail::TensorFillLinearFunc; using Params = typename Func::Params; TensorForEach( view.extent(), - Params(view, v, s) + Params(view, v, s), + /*grid_size*/0, /*block_size*/0, + stream ); } /////////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////// +/// Fills a tensor with random values from a distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandom( + TensorView view, ///< destination tensor + uint64_t seed, + Distribution dist, + cudaStream_t stream = nullptr) { + + using Real = typename RealType::Type; + + if (dist.kind == Distribution::Gaussian) { + TensorFillRandomGaussian( + view, + seed, + static_cast(dist.gaussian.mean), + static_cast(dist.gaussian.stddev), + dist.int_scale, + stream); + } else if (dist.kind == Distribution::Uniform) { + TensorFillRandomUniform( + view, + seed, + static_cast(dist.uniform.max), + static_cast(dist.uniform.min), + dist.int_scale, + stream); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + /// Fills a block of data with sequential elements template < typename Element @@ -1670,27 +1752,30 @@ void BlockFillRandom( Element *ptr, size_t capacity, uint64_t seed, - Distribution dist) { + Distribution dist, + cudaStream_t stream = nullptr) { using Real = typename RealType::Type; if (dist.kind == Distribution::Gaussian) { BlockFillRandomGaussian( - ptr, - capacity, - seed, - static_cast(dist.gaussian.mean), - static_cast(dist.gaussian.stddev), - dist.int_scale); + ptr, + capacity, + seed, + static_cast(dist.gaussian.mean), + static_cast(dist.gaussian.stddev), + dist.int_scale, + stream); } else if (dist.kind == Distribution::Uniform) { BlockFillRandomUniform( - ptr, - capacity, - seed, + ptr, + capacity, + seed, static_cast(dist.uniform.max), - static_cast(dist.uniform.min), - dist.int_scale); + static_cast(dist.uniform.min), + dist.int_scale, + stream); } } @@ -1786,14 +1871,17 @@ template < typename Layout> ///< Layout function void TensorCopyDiagonalIn( TensorView view, ///< destination tensor - Element const *ptr) { ///< dense buffer of elements + Element const *ptr, ///< dense buffer of elements + cudaStream_t stream = nullptr) { using Func = detail::TensorCopyDiagonalInFunc; using Params = typename Func::Params; TensorForEach( view.extent(), - Params(view, ptr) + Params(view, ptr), + /*grid_size*/0, /*block_size*/0, + stream ); } @@ -1890,14 +1978,17 @@ template < typename Layout> ///< Layout function void TensorCopyDiagonalOut( Element *ptr, ///< dense buffer of elements - TensorView view) { ///< source tensor + TensorView view, ///< source tensor + cudaStream_t stream = nullptr) { using Func = detail::TensorCopyDiagonalOutFunc; using Params = typename Func::Params; TensorForEach( view.extent(), - Params(view, ptr) + Params(view, ptr), + /*grid_size*/0, /*block_size*/0, + stream ); } diff --git a/tools/util/include/cutlass/util/reference/device/tensor_foreach.h b/tools/util/include/cutlass/util/reference/device/tensor_foreach.h index bb6f935ed2..bae68e7037 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_foreach.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_foreach.h @@ -45,7 +45,10 @@ template struct TensorForEach { /// Constructor performs the operation. - TensorForEach(Coord size, Params params = Params(), int grid_size = 0, int block_size = 0) { + TensorForEach( + Coord size, Params params = Params(), + int grid_size = 0, int block_size = 0, + cudaStream_t stream = nullptr) { if (!grid_size || !block_size) { @@ -67,7 +70,7 @@ struct TensorForEach { dim3 grid(grid_size, 1, 1); dim3 block(block_size, 1, 1); - kernel::TensorForEach<<< grid, block >>>(size, params); + kernel::TensorForEach<<< grid, block, 0, stream >>>(size, params); } }; @@ -78,7 +81,10 @@ template struct TensorDiagonalForEach { /// Constructor performs the operation - TensorDiagonalForEach(Coord size, Params params = Params(), int start = 0, int end = -1, int block_size = 128) { + TensorDiagonalForEach( + Coord size, Params params = Params(), + int start = 0, int end = -1, + int block_size = 128, cudaStream_t stream = nullptr) { if (end < 0) { end = size.min(); @@ -87,7 +93,8 @@ struct TensorDiagonalForEach { dim3 block(block_size, 1, 1); dim3 grid((end - start + block_size - 1) / block_size, 1, 1); - kernel::TensorDiagonalForEach<<< grid, block >>>(size, params, start, end); + kernel::TensorDiagonalForEach<<< grid, block, 0, stream >>>( + size, params, start, end); } }; @@ -99,11 +106,12 @@ struct BlockForEach { /// Constructor performs the operation. BlockForEach( - Element *ptr, + Element *ptr, size_t capacity, typename Func::Params params = typename Func::Params(), - int grid_size = 0, - int block_size = 0) { + int grid_size = 0, + int block_size = 0, + cudaStream_t stream = nullptr) { if (!grid_size || !block_size) { @@ -125,7 +133,7 @@ struct BlockForEach { dim3 grid(grid_size, 1, 1); dim3 block(block_size, 1, 1); - kernel::BlockForEach<<< grid, block >>>(ptr, capacity, params); + kernel::BlockForEach<<< grid, block, 0, stream >>>(ptr, capacity, params); } }; diff --git a/tools/util/include/cutlass/util/reference/host/gemm.h b/tools/util/include/cutlass/util/reference/host/gemm.h index f70e069966..85cf51c930 100644 --- a/tools/util/include/cutlass/util/reference/host/gemm.h +++ b/tools/util/include/cutlass/util/reference/host/gemm.h @@ -372,6 +372,84 @@ struct Gemm +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + //////////////////////////////////////////////////////////////////////////////////////////////////// // // Batched GEMM diff --git a/tools/util/include/cutlass/util/reference/host/gett.hpp b/tools/util/include/cutlass/util/reference/host/gett.hpp index f87e3d8e9a..6152f117fc 100644 --- a/tools/util/include/cutlass/util/reference/host/gett.hpp +++ b/tools/util/include/cutlass/util/reference/host/gett.hpp @@ -73,23 +73,29 @@ struct GettMainloopParams { template< class ElementScalar_, + class ElementScalingFactor_, class ElementAccumulator_, class ElementCompute_, class TensorC_, // (M, N, L) class TensorD_, // (M, N, L) - class TensorBias_, // (M, 1) - class TensorT_, // (M, N, L) + class VectorBias_ = TensorD_, // (M, 1) + class TensorAux_ = TensorD_, // (M, N, L) + class VectorAlpha_ = TensorD_, // (M, 1) + class VectorBeta_ = VectorAlpha_, // (M, 1) class ActivationFunctor_ = cutlass::epilogue::thread::Identity, class BiasBinaryOp_ = cutlass::plus > struct GettEpilogueParams { using ElementScalar = ElementScalar_; + using ElementScalingFactor = ElementScalingFactor_; using ElementAccumulator = ElementAccumulator_; using ElementCompute = ElementCompute_; using TensorC = TensorC_; using TensorD = TensorD_; - using TensorBias = TensorBias_; - using TensorT = TensorT_; + using TensorAux = TensorAux_; + using VectorBias = VectorBias_; + using VectorAlpha = VectorAlpha_; + using VectorBeta = VectorBeta_; using ActivationFunctor = ActivationFunctor_; using BiasBinaryOp = BiasBinaryOp_; @@ -97,17 +103,27 @@ struct GettEpilogueParams { using LayoutC = typename TensorC::layout_type; using EngineD = typename TensorD::engine_type; using LayoutD = typename TensorD::layout_type; - using EngineBias = typename TensorBias::engine_type; - using LayoutBias = typename TensorBias::layout_type; - using EngineT = typename TensorT::engine_type; - using LayoutT = typename TensorT::layout_type; + ElementScalar alpha = ElementScalar(1); ElementScalar beta = ElementScalar(0); TensorC C{}; TensorD D{}; - TensorBias Bias{}; - TensorT T{}; + VectorBias Bias{}; + TensorAux Aux{}; + VectorAlpha Valpha{}; + VectorBeta Vbeta{}; + + ElementAccumulator* abs_max_D = nullptr; + ElementAccumulator* abs_max_Aux = nullptr; + + ElementScalingFactor scale_a = ElementScalingFactor(1); + ElementScalingFactor scale_b = ElementScalingFactor(1); + ElementScalingFactor scale_c = ElementScalingFactor(1); + ElementScalingFactor scale_d = ElementScalingFactor(1); + ElementScalingFactor scale_aux = ElementScalingFactor(1); + + bool beta_per_channel_scaling = false; }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -125,7 +141,9 @@ void Gett( static int constexpr kBlockM = 64; static int constexpr kBlockN = 64; +#if defined(_OPENMP) #pragma omp parallel for collapse(3) +#endif for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { @@ -152,8 +170,8 @@ void gett_mainloop( static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); - using ElementA = typename MainloopParams::EngineA::value_type; - using ElementB = typename MainloopParams::EngineB::value_type; + using ElementA = typename MainloopParams::TensorA::value_type; + using ElementB = typename MainloopParams::TensorB::value_type; using RingOp = multiply_add; RingOp fma_op; @@ -217,16 +235,23 @@ void gett_epilogue( static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); using ElementCompute = typename EpilogueParams::ElementCompute; - using ElementC = typename EpilogueParams::EngineC::value_type; - - using ElementD = typename EpilogueParams::EngineD::value_type; - using ElementBias = typename EpilogueParams::EngineBias::value_type; - using ElementT = typename EpilogueParams::EngineT::value_type; - + using ElementC = typename EpilogueParams::TensorC::value_type; + using ElementD = typename EpilogueParams::TensorD::value_type; + using ElementAux = typename EpilogueParams::TensorAux::value_type; + using ElementBias = typename EpilogueParams::VectorBias::value_type; using ElementScalar = typename EpilogueParams::ElementScalar; + using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor; using ActivationFunctor = typename EpilogueParams::ActivationFunctor; using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp; + constexpr bool IsScalingAndAmaxOutputNeeded = + std::is_same_v or + std::is_same_v; + + constexpr bool IsScalingAndAmaxAuxOutputNeeded = + std::is_same_v or + std::is_same_v; + // Input related converter NumericConverter accumulator_converter; NumericConverter source_converter; @@ -234,9 +259,15 @@ void gett_epilogue( // Scale related converter NumericConverter scale_converter; + NumericConverter scaling_factor_converter; + + // Abs max converter + [[maybe_unused]] NumericConverter abs_max_output_converter; + // Output related converter NumericConverter destination_converter; - NumericConverter temporary_converter; + NumericConverter aux_destination_converter; + // Epilogue operations multiply_add epilogue_fma; multiplies mul; @@ -250,13 +281,30 @@ void gett_epilogue( // Do conversion ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); ElementCompute converted_beta = scale_converter(epilogue_params.beta); + ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a); + ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b); + ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c); + ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d); + ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux); + + // Init local var + [[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0); + [[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0); + + converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); + converted_beta = mul(converted_beta, converted_scale_c); + for (int n_b = 0; n_b < kBlockN; ++n_b) { for (int m_b = 0; m_b < kBlockM; ++m_b) { if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { // Convert every type to ElementCompute first, do compute, convert to output type, write it out ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); - + // per-row alpha + if (epilogue_params.Valpha.data()) { + converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b)); + } ElementCompute output = mul(converted_alpha, converted_acc); + if (epilogue_params.Bias.data()) { ElementCompute converted_bias = bias_converter(epilogue_params.Bias(m + m_b)); output = bias_op(output, converted_bias); @@ -264,20 +312,54 @@ void gett_epilogue( if (epilogue_params.C.data()) { ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); + // per-row beta + if (epilogue_params.Vbeta.data()) { + converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b)); + } output = epilogue_fma(converted_beta, converted_src, output); } - if (epilogue_params.T.data()) { - // Store intermediate output - epilogue_params.T(m + m_b, n + n_b, l) = temporary_converter(output); + if (epilogue_params.Aux.data()) { + auto aux_output = output; + if constexpr (IsScalingAndAmaxAuxOutputNeeded) { + maximum_absolute_value_reduction amax_op; + local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output); + aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0)); + } + + epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output); } output = activation(output); + if constexpr (IsScalingAndAmaxOutputNeeded) { + maximum_absolute_value_reduction amax_op; + local_abs_max_output = amax_op(local_abs_max_output, output); + output = epilogue_fma(converted_scale_d, output, ElementCompute(0)); + } + epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(output); } } } +#if defined(_OPENMP) + #pragma omp critical(Abs_Max_Data_Update) +#endif + { + if constexpr (IsScalingAndAmaxOutputNeeded) { + if (epilogue_params.abs_max_D) { + *epilogue_params.abs_max_D = maximum_with_nan_propogation{}( + *epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output)); + } + } + + if constexpr (IsScalingAndAmaxAuxOutputNeeded) { + if (epilogue_params.abs_max_Aux) { + *epilogue_params.abs_max_Aux = maximum_with_nan_propogation{}( + *epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output)); + } + } + } } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -310,15 +392,19 @@ void Gemm3x( Layout layout_B = make_layout_rank3(mainloop_params.B); Layout layout_C = make_layout_rank3(epilogue_params.C); Layout layout_D = make_layout_rank3(epilogue_params.D); + Layout layout_Aux = make_layout_rank3(epilogue_params.Aux); Layout layout_Bias = make_layout_rank3(epilogue_params.Bias); - Layout layout_T = make_layout_rank3(epilogue_params.T); - + Layout layout_Valpha = make_layout_rank3(epilogue_params.Valpha); + Layout layout_Vbeta = make_layout_rank3(epilogue_params.Vbeta); + auto TensorA = make_tensor(mainloop_params.A.data(), layout_A); auto TensorB = make_tensor(mainloop_params.B.data(), layout_B); auto TensorC = make_tensor(epilogue_params.C.data(), layout_C); auto TensorD = make_tensor(epilogue_params.D.data(), layout_D); - auto TensorBias = make_tensor(epilogue_params.Bias.data(), layout_Bias); - auto TensorT = make_tensor(epilogue_params.T.data(), layout_T); + auto TensorAux = make_tensor(epilogue_params.Aux.data(), layout_Aux); + auto VectorBias = make_tensor(epilogue_params.Bias.data(), layout_Bias); + auto VectorAlpha = make_tensor(epilogue_params.Valpha.data(), layout_Valpha); + auto VectorBeta = make_tensor(epilogue_params.Vbeta.data(), layout_Vbeta); // Reconstruct mainloop params GettMainloopParams epilogue_params_converted{epilogue_params.alpha, epilogue_params.beta, TensorC, TensorD, - TensorBias, - TensorT + VectorBias, + TensorAux, + VectorAlpha, + VectorBeta, + epilogue_params.abs_amax_D, + epilogue_params.abs_amax_Aux, + epilogue_params.scale_a, + epilogue_params.scale_b, + epilogue_params.scale_c, + epilogue_params.scale_d, + epilogue_params.scale_aux }; Gett(mainloop_params_converted, epilogue_params_converted); diff --git a/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/tools/util/include/cutlass/util/reference/host/tensor_fill.h index 3db176edbf..9b0dcdb374 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_fill.h @@ -38,6 +38,7 @@ #include #include #include +#include // Cutlass includes #include "cutlass/cutlass.h" @@ -156,6 +157,7 @@ struct RandomGaussianFunc { double stddev; int int_scale; double pi; + double pnz; // // Methods @@ -164,9 +166,10 @@ struct RandomGaussianFunc { uint64_t seed_ = 0, double mean_ = 0, double stddev_ = 1, - int int_scale_ = -1 + int int_scale_ = -1, + double pnz_ = 100.0 ): - seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_) { std::srand((unsigned)seed); } @@ -184,12 +187,24 @@ struct RandomGaussianFunc { // Scale and convert final result Element result; - if (int_scale >= 0) { - rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); - result = static_cast(rnd); + // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian + std::random_device rnd_device; + std::mt19937 bernoulli_rnd(rnd_device()); + std::bernoulli_distribution bernoulli_dist(pnz / 100); + bool bernoulli_result = bernoulli_dist(bernoulli_rnd); + + // Sample from the Gaussian distribution for a nonzero element + if (bernoulli_result) { + if (int_scale >= 0) { + rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(rnd); + } + else { + result = static_cast(rnd); + } } else { - result = static_cast(rnd); + result = static_cast(0); } return result; @@ -205,6 +220,7 @@ struct RandomGaussianFunc > { double stddev; int int_scale; double pi; + double pnz; // // Methods @@ -213,9 +229,10 @@ struct RandomGaussianFunc > { uint64_t seed_ = 0, double mean_ = 0, double stddev_ = 1, - int int_scale_ = -1 + int int_scale_ = -1, + double pnz_ = 100.0 ): - seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_) { std::srand((unsigned)seed); } @@ -228,14 +245,28 @@ struct RandomGaussianFunc > { detail::BoxMullerFunc func; func(rnd, mean, stddev, pi); - if (int_scale >= 0) { - rnd[0] = double(int(rnd[0] * double(1 << int_scale))); - rnd[1] = double(int(rnd[1] * double(1 << int_scale))); - reals[0] = from_real(rnd[0] / double(1 << int_scale)); - reals[1] = from_real(rnd[1] / double(1 << int_scale)); - } else { - reals[0] = from_real(rnd[0]); - reals[1] = from_real(rnd[1]); + // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian + std::random_device rnd_device; + std::mt19937 bernoulli_rnd(rnd_device()); + std::bernoulli_distribution bernoulli_dist(pnz / 100); + bool bernoulli_result = bernoulli_dist(bernoulli_rnd); + + // Sample from the Gaussian distribution for a nonzero element + if (bernoulli_result) { + if (int_scale >= 0) { + rnd[0] = double(int(rnd[0] * double(1 << int_scale))); + rnd[1] = double(int(rnd[1] * double(1 << int_scale))); + reals[0] = from_real(rnd[0] / double(1 << int_scale)); + reals[1] = from_real(rnd[1] / double(1 << int_scale)); + } + else { + reals[0] = from_real(rnd[0]); + reals[1] = from_real(rnd[1]); + } + } + else { + reals[0] = from_real(0); + reals[1] = from_real(0); } return complex(reals[0], reals[1]); @@ -251,6 +282,7 @@ struct RandomGaussianFunc > { double stddev; int int_scale; double pi; + double pnz; // // Methods @@ -259,9 +291,10 @@ struct RandomGaussianFunc > { uint64_t seed_ = 0, double mean_ = 0, double stddev_ = 1, - int int_scale_ = -1 + int int_scale_ = -1, + double pnz_ = 100.0 ): - seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_) { std::srand((unsigned)seed); } @@ -276,21 +309,37 @@ struct RandomGaussianFunc > { func(rnd1, mean, stddev, pi); func(rnd2, mean, stddev, pi); - if (int_scale >= 0) { - rnd1[0] = double(int(rnd1[0] * double(1 << int_scale))); - rnd1[1] = double(int(rnd1[1] * double(1 << int_scale))); - rnd2[0] = double(int(rnd2[0] * double(1 << int_scale))); - rnd2[1] = double(int(rnd2[1] * double(1 << int_scale))); - - reals[0] = from_real(rnd1[0] / double(1 << int_scale)); - reals[1] = from_real(rnd1[1] / double(1 << int_scale)); - reals[2] = from_real(rnd2[0] / double(1 << int_scale)); - reals[3] = from_real(rnd2[1] / double(1 << int_scale)); - } else { - reals[0] = from_real(rnd1[0]); - reals[1] = from_real(rnd1[1]); - reals[2] = from_real(rnd2[0]); - reals[3] = from_real(rnd2[1]); + // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian + std::random_device rnd_device; + std::mt19937 bernoulli_rnd(rnd_device()); + std::bernoulli_distribution bernoulli_dist(pnz / 100); + bool bernoulli_result = bernoulli_dist(bernoulli_rnd); + + // Sample from the Gaussian distribution for a nonzero element + if (bernoulli_result) { + if (int_scale >= 0) { + rnd1[0] = double(int(rnd1[0] * double(1 << int_scale))); + rnd1[1] = double(int(rnd1[1] * double(1 << int_scale))); + rnd2[0] = double(int(rnd2[0] * double(1 << int_scale))); + rnd2[1] = double(int(rnd2[1] * double(1 << int_scale))); + + reals[0] = from_real(rnd1[0] / double(1 << int_scale)); + reals[1] = from_real(rnd1[1] / double(1 << int_scale)); + reals[2] = from_real(rnd2[0] / double(1 << int_scale)); + reals[3] = from_real(rnd2[1] / double(1 << int_scale)); + } + else { + reals[0] = from_real(rnd1[0]); + reals[1] = from_real(rnd1[1]); + reals[2] = from_real(rnd2[0]); + reals[3] = from_real(rnd2[1]); + } + } + else { + reals[0] = from_real(0); + reals[1] = from_real(0); + reals[2] = from_real(0); + reals[3] = from_real(0); } return Quaternion(reals[0], reals[1], reals[2], reals[3]); @@ -389,11 +438,11 @@ void TensorFillRandomGaussian( uint64_t seed, ///< seed for RNG double mean = 0, ///< Gaussian distribution's mean double stddev = 1, ///< Gaussian distribution's standard deviation - int bits = -1) { ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 100.0) { /// are not truncated to zero. Permits reducing precision of /// data. - detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz); detail::TensorFillGaussianFunc func( dst, @@ -411,16 +460,16 @@ template < typename Element, ///< Element type typename Layout> ///< Layout function void TensorFillRandomGaussian( - TensorViewPlanarComplex dst, ///< destination tensor - uint64_t seed, ///< seed for RNG - double mean = 0, ///< Gaussian distribution's mean - double stddev = 1, ///< Gaussian distribution's standard deviation - int bits = -1) { ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of - /// data. + TensorViewPlanarComplex dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 100.0) { /// are not truncated to zero. Permits reducing precision of + /// data. - TensorFillRandomGaussian(dst.view_real(), seed, mean, stddev, bits); - TensorFillRandomGaussian(dst.view_imag(), ~seed, mean, stddev, bits); + TensorFillRandomGaussian(dst.view_real(), seed, mean, stddev, bits, pnz); + TensorFillRandomGaussian(dst.view_imag(), ~seed, mean, stddev, bits, pnz); } /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -434,11 +483,11 @@ void TensorFillSymmetricRandomGaussian( cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices double mean = 0, ///< Gaussian distribution's mean double stddev = 1, ///< Gaussian distribution's standard deviation - int bits = -1) { ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 100.0) { /// are not truncated to zero. Permits reducing precision of /// data. - detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz); detail::TensorFillSymmetricGaussianFunc func( dst, @@ -464,12 +513,12 @@ void BlockFillRandomGaussian( uint64_t seed, ///< seed for RNG double mean = 0, ///< Gaussian distribution's mean double stddev = 1, ///< Gaussian distribution's standard deviation - int bits = -1) { ///< If non-negative, specifies number of fractional bits that - /// are not truncated to zero. Permits reducing precision of + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 100.0) { /// are not truncated to zero. Permits reducing precision of /// data. - detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz); for (size_t i = 0; i < capacity; ++i) { ReferenceFactory::get(ptr, i) = random_func(); @@ -801,10 +850,10 @@ void TensorFillRandomUniform( uint64_t seed, ///< seed for RNG double max = 1, ///< upper bound of distribution double min = 0, ///< lower bound for distribution - int bits = -1) { ///< If non-negative, specifies number of fractional bits that + int bits = -1) { ///< If non-negative, specifies number of fractional bits that /// are not truncated to zero. Permits reducing precision of - /// data. - + /// data. + TensorFillRandomUniform(dst.view_real(), seed, max, min, bits); TensorFillRandomUniform(dst.view_imag(), ~seed, max, min, bits); } @@ -1189,6 +1238,37 @@ void TensorFillSequential( /////////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////// +/// Fills a tensor with random values from a distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandom( + TensorView view, ///< destination tensor + uint64_t seed, + Distribution dist) { + + using Real = typename RealType::Type; + + if (dist.kind == Distribution::Gaussian) { + TensorFillRandomGaussian( + view, + seed, + static_cast(dist.gaussian.mean), + static_cast(dist.gaussian.stddev), + dist.int_scale); + } else if (dist.kind == Distribution::Uniform) { + TensorFillRandomUniform( + view, + seed, + static_cast(dist.uniform.max), + static_cast(dist.uniform.min), + dist.int_scale); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + /// Fills a block of data with sequential elements template < typename Element @@ -1250,7 +1330,8 @@ void BlockFillRandom( seed, dist.gaussian.mean, dist.gaussian.stddev, - dist.int_scale); + dist.int_scale, + dist.gaussian.pnz); } else if (dist.kind == Distribution::Uniform) { BlockFillRandomUniform(