From c53f3339bb95ba52c5571134b1b3407e95a2546b Mon Sep 17 00:00:00 2001 From: Andrew Kerr Date: Wed, 23 Sep 2020 14:00:58 -0700 Subject: [PATCH] CUTLASS 2.3 initial commit (#134) CUTLASS 2.3 adds GEMMs targeting Sparse Tensor Cores on the NVIDIA Ampere Architecture, fast SGEMM, and small matrix classes, bug fixes, and performance enhancements. --- CHANGELOG.md | 14 + CMakeLists.txt | 16 +- CONTRIBUTORS.md | 5 + CUDA.cmake | 9 +- README.md | 41 +- examples/02_dump_reg_shmem/dump_reg_shmem.cu | 7 +- examples/10_planar_complex/planar_complex.cu | 9 +- .../planar_complex_array.cu | 9 +- .../b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm80.h | 205 + .../b2b_interleaved_gemm_run.h | 40 +- examples/13_fused_two_gemms/fused_gemm.cu | 9 +- examples/13_fused_two_gemms/kernel/b2b_gemm.h | 6 +- .../kernel/default_b2b_gemm.h | 80 +- .../threadblock/b2b_mma_multistage.h | 862 + .../threadblock/b2b_mma_pipelined.h | 4 - .../threadblock/default_b2b_mma.h | 112 +- .../CMakeLists.txt | 27 + .../ampere_tf32_tensorop_gemm.cu | 278 + .../CMakeLists.txt | 27 + .../ampere_sparse_tensorop_gemm.cu | 311 + examples/CMakeLists.txt | 2 + include/cutlass/arch/arch.h | 11 + include/cutlass/arch/memory.h | 3 +- include/cutlass/arch/mma.h | 37 + include/cutlass/arch/mma_sm50.h | 12 +- include/cutlass/arch/mma_sm60.h | 6 +- include/cutlass/arch/mma_sm61.h | 4 +- include/cutlass/arch/mma_sm80.h | 1 - include/cutlass/arch/simd.h | 2 +- include/cutlass/arch/sp_mma_sm80.h | 1591 ++ include/cutlass/array.h | 40 + include/cutlass/bfloat16.h | 2 +- include/cutlass/complex.h | 8 +- include/cutlass/constants.h | 1233 ++ include/cutlass/coord.h | 6 + include/cutlass/core_io.h | 31 +- include/cutlass/cutlass.h | 11 +- include/cutlass/epilogue/thread/activation.h | 37 +- .../epilogue/thread/linear_combination.h | 14 + .../thread/linear_combination_clamp.h | 29 +- .../epilogue/thread/linear_combination_gelu.h | 206 + .../default_epilogue_complex_tensor_op.h | 8 +- .../threadblock/default_epilogue_tensor_op.h | 1 - .../threadblock/default_thread_map_simt.h | 2 +- .../default_thread_map_tensor_op.h | 4 +- .../default_thread_map_volta_tensor_op.h | 4 +- .../default_thread_map_wmma_tensor_op.h | 2 +- .../cutlass/epilogue/threadblock/epilogue.h | 13 +- .../threadblock/predicated_tile_iterator.h | 13 +- .../cutlass/epilogue/warp/tensor_op_policy.h | 6 +- .../epilogue/warp/tile_iterator_tensor_op.h | 252 +- .../warp/tile_iterator_volta_tensor_op.h | 1 + include/cutlass/fast_math.h | 189 +- include/cutlass/functional.h | 78 +- include/cutlass/gemm/device/gemm_sparse.h | 517 + .../cutlass/gemm/device/gemm_universal_base.h | 75 +- include/cutlass/gemm/gemm.h | 4 +- .../cutlass/gemm/kernel/default_gemm_sparse.h | 187 + include/cutlass/gemm/kernel/gemm.h | 6 +- include/cutlass/gemm/kernel/gemm_array.h | 6 +- include/cutlass/gemm/kernel/gemm_batched.h | 6 +- include/cutlass/gemm/kernel/gemm_pipelined.h | 4 +- .../cutlass/gemm/kernel/gemm_planar_complex.h | 6 +- .../gemm/kernel/gemm_planar_complex_array.h | 3 +- .../gemm/kernel/gemm_splitk_parallel.h | 6 +- include/cutlass/gemm/kernel/gemm_universal.h | 22 +- include/cutlass/gemm/kernel/sparse_gemm.h | 392 + include/cutlass/gemm/thread/mma_sm50.h | 10 + include/cutlass/gemm/thread/mma_sm60.h | 52 +- include/cutlass/gemm/thread/mma_sm61.h | 63 +- .../cutlass/gemm/threadblock/default_mma.h | 1 + .../gemm/threadblock/default_mma_core_sm50.h | 197 - .../gemm/threadblock/default_mma_core_sm75.h | 11 +- .../gemm/threadblock/default_mma_core_sm80.h | 11 +- .../default_mma_core_sparse_sm80.h | 828 + 
.../gemm/threadblock/default_sparse_mma.h | 190 + .../gemm/threadblock/mma_sparse_base.h | 259 + .../gemm/threadblock/mma_sparse_multistage.h | 667 + .../gemm/threadblock/threadblock_swizzle.h | 22 +- .../gemm/warp/default_mma_sparse_tensor_op.h | 159 + include/cutlass/gemm/warp/mma_simt.h | 3 + .../cutlass/gemm/warp/mma_sparse_tensor_op.h | 335 + include/cutlass/gemm/warp/mma_tensor_op.h | 54 +- .../warp/mma_tensor_op_fragment_iterator.h | 70 +- .../cutlass/gemm/warp/mma_tensor_op_sm70.h | 31 +- .../gemm/warp/mma_tensor_op_tile_iterator.h | 77 +- .../warp/mma_tensor_op_tile_iterator_sm80.h | 833 + .../warp/mma_tensor_op_tile_iterator_sparse.h | 374 + include/cutlass/half.h | 2 +- include/cutlass/integer_subbyte.h | 23 +- include/cutlass/layout/matrix.h | 19 +- include/cutlass/layout/tensor.h | 116 +- .../layout/tensor_op_multiplicand_sm75.h | 16 +- include/cutlass/matrix.h | 14111 ++++++++++++++++ include/cutlass/numeric_conversion.h | 217 +- include/cutlass/quaternion.h | 616 + include/cutlass/real.h | 5 + include/cutlass/reduction/batched_reduction.h | 179 - .../reduction/batched_reduction_traits.h | 192 - include/cutlass/relatively_equal.h | 12 + include/cutlass/tensor_coord.h | 142 + include/cutlass/tensor_view.h | 6 + include/cutlass/tfloat32.h | 2 +- include/cutlass/{matrix_traits.h => trace.h} | 38 +- .../transform/pitch_linear_thread_map.h | 66 +- .../predicated_tile_access_iterator.h | 6 +- .../regular_tile_access_iterator_tensor_op.h | 10 +- .../regular_tile_iterator_pitch_linear.h | 9 +- media/docs/functionality.md | 46 + media/docs/gemm_api.md | 27 + media/docs/profiler.md | 2 +- media/docs/quickstart.md | 14 +- test/unit/CMakeLists.txt | 1 + test/unit/core/CMakeLists.txt | 2 + test/unit/core/bfloat16.cu | 3 + test/unit/core/complex.cu | 73 +- test/unit/core/half.cu | 3 + test/unit/core/matrix.cu | 198 + test/unit/core/quaternion.cu | 162 + test/unit/core/tensor_view.cu | 12 +- test/unit/core/tfloat32.cu | 3 + test/unit/gemm/device/CMakeLists.txt | 15 + ...16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu | 266 + ...16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu | 267 + ...16n_f16t_f16t_tensor_op_f16_sparse_sm80.cu | 265 + ...16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu | 266 + ...16t_f16n_f16t_tensor_op_f16_sparse_sm80.cu | 267 + ...16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu | 265 + ...16t_f16t_f32t_tensor_op_f32_sparse_sm80.cu | 193 + ...32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu | 423 + ...32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu | 423 + ...32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu | 422 + ...32t_f32t_f32t_tensor_op_f32_sparse_sm80.cu | 423 + ..._s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu | 261 + .../gemm_s8t_s8n_s32n_tensor_op_s32_sm80.cu | 355 + ..._s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu | 263 + test/unit/gemm/device/testbed_sparse.h | 440 + test/unit/gemm/threadblock/CMakeLists.txt | 5 + test/unit/gemm/threadblock/mma_multistage.cu | 3827 +++++ .../gemm/threadblock/mma_multistage_sparse.cu | 2697 +++ .../mma_multistage_sparse_testbed.h | 398 + .../gemm/threadblock/mma_multistage_testbed.h | 329 + .../gemm/threadblock/mma_pipelined_sm80.cu | 563 + .../threadblock/mma_planar_complex_sm80.cu | 73 + test/unit/gemm/warp/gemm_complex_sm80.cu | 61 +- test/unit/gemm/warp/gemm_sm80.cu | 76 + test/unit/gemm/warp/gemm_sparse_sm80.cu | 1101 ++ test/unit/gemm/warp/testbed.h | 357 +- test/unit/layout/matrix.cu | 5 +- test/unit/util/CMakeLists.txt | 26 + test/unit/util/complex.cu | 102 - test/unit/util/tensor_reduce.cu | 238 + tools/CMakeLists.txt | 9 +- tools/library/CMakeLists.txt | 4 + 
.../include/cutlass/library/arch_mappings.h | 99 + .../library/include/cutlass/library/library.h | 82 +- .../include/cutlass/library/manifest.h | 2 +- tools/library/scripts/gemm_operation.py | 90 +- tools/library/scripts/generator.py | 655 +- tools/library/scripts/library.py | 10 +- tools/library/scripts/manifest.py | 51 +- tools/library/src/gemm_operation.h | 248 +- tools/library/src/library_internal.h | 67 +- tools/library/src/manifest.cpp | 1 - tools/library/src/reference/gemm.cu | 335 + .../src/reference/gemm_reference_operation.h | 472 + .../initialize_reference_operations.cu | 53 + tools/library/src/util.cu | 41 +- tools/profiler/CMakeLists.txt | 1 + tools/profiler/src/cublas_helpers.cpp | 99 +- tools/profiler/src/cublas_helpers.h | 3 + tools/profiler/src/cutlass_profiler.cu | 4 + tools/profiler/src/device_allocation.cu | 470 +- tools/profiler/src/device_allocation.h | 34 +- tools/profiler/src/device_context.cu | 62 +- tools/profiler/src/device_context.h | 18 +- tools/profiler/src/gemm_operation_profiler.cu | 381 +- tools/profiler/src/gemm_operation_profiler.h | 37 +- tools/profiler/src/operation_profiler.cu | 21 +- tools/profiler/src/operation_profiler.h | 5 +- tools/profiler/src/options.cu | 33 +- tools/profiler/src/options.h | 7 + tools/profiler/src/performance_report.cpp | 4 + tools/profiler/src/problem_space.cpp | 84 +- tools/profiler/src/problem_space.h | 101 +- .../src/sparse_gemm_operation_profiler.cu | 560 + .../src/sparse_gemm_operation_profiler.h | 208 + tools/util/include/cutlass/util/exceptions.h | 7 +- .../util/include/cutlass/util/host_reorder.h | 30 + tools/util/include/cutlass/util/host_tensor.h | 1 - .../cutlass/util/host_tensor_planar_complex.h | 1 - .../include/cutlass/util/host_uncompress.h | 117 + .../reference/detail/linear_to_coordinate.h | 88 + .../cutlass/util/reference/device/gemm.h | 1 - .../util/reference/device/gemm_complex.h | 295 + .../reference/device/gemm_planar_complex.h | 1 - .../util/reference/device/kernel/gemm.h | 1 - .../reference/device/kernel/tensor_foreach.h | 4 +- .../util/reference/device/tensor_compare.h | 11 +- .../util/reference/device/tensor_fill.h | 285 +- .../util/reference/device/tensor_foreach.h | 2 +- .../util/reference/device/tensor_reduce.h | 505 + .../util/reference/device/thread/gemm.h | 1 - .../cutlass/util/reference/host/gemm.h | 1 - .../util/reference/host/gemm_complex.h | 91 +- .../util/reference/host/gemm_planar_complex.h | 1 - .../cutlass/util/reference/host/tensor_fill.h | 142 +- .../cutlass/util/reference/host/tensor_norm.h | 44 +- .../util/reference/host/tensor_reduce.h | 197 + 209 files changed, 46919 insertions(+), 1674 deletions(-) create mode 100644 examples/13_fused_two_gemms/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm80.h create mode 100644 examples/13_fused_two_gemms/threadblock/b2b_mma_multistage.h create mode 100644 examples/14_ampere_tf32_tensorop_gemm/CMakeLists.txt create mode 100644 examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu create mode 100644 examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt create mode 100644 examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu create mode 100644 include/cutlass/arch/sp_mma_sm80.h create mode 100644 include/cutlass/constants.h create mode 100644 include/cutlass/epilogue/thread/linear_combination_gelu.h create mode 100644 include/cutlass/gemm/device/gemm_sparse.h create mode 100644 include/cutlass/gemm/kernel/default_gemm_sparse.h create mode 100644 include/cutlass/gemm/kernel/sparse_gemm.h delete mode 100644 
include/cutlass/gemm/threadblock/default_mma_core_sm50.h create mode 100644 include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h create mode 100644 include/cutlass/gemm/threadblock/default_sparse_mma.h create mode 100644 include/cutlass/gemm/threadblock/mma_sparse_base.h create mode 100644 include/cutlass/gemm/threadblock/mma_sparse_multistage.h create mode 100644 include/cutlass/gemm/warp/default_mma_sparse_tensor_op.h create mode 100644 include/cutlass/gemm/warp/mma_sparse_tensor_op.h create mode 100644 include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h create mode 100644 include/cutlass/matrix.h create mode 100644 include/cutlass/quaternion.h delete mode 100644 include/cutlass/reduction/batched_reduction.h delete mode 100644 include/cutlass/reduction/batched_reduction_traits.h rename include/cutlass/{matrix_traits.h => trace.h} (79%) create mode 100644 test/unit/core/matrix.cu create mode 100644 test/unit/core/quaternion.cu create mode 100644 test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu create mode 100644 test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu create mode 100644 test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sparse_sm80.cu create mode 100644 test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu create mode 100644 test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sparse_sm80.cu create mode 100644 test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu create mode 100644 test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sparse_sm80.cu create mode 100644 test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu create mode 100644 test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu create mode 100644 test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu create mode 100644 test/unit/gemm/device/gemm_f32t_f32t_f32t_tensor_op_f32_sparse_sm80.cu create mode 100644 test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu create mode 100644 test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu create mode 100644 test/unit/gemm/device/testbed_sparse.h create mode 100644 test/unit/gemm/threadblock/mma_multistage.cu create mode 100644 test/unit/gemm/threadblock/mma_multistage_sparse.cu create mode 100644 test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h create mode 100644 test/unit/gemm/threadblock/mma_multistage_testbed.h create mode 100644 test/unit/gemm/threadblock/mma_pipelined_sm80.cu create mode 100644 test/unit/gemm/threadblock/mma_planar_complex_sm80.cu create mode 100644 test/unit/gemm/warp/gemm_sparse_sm80.cu create mode 100644 test/unit/util/CMakeLists.txt delete mode 100644 test/unit/util/complex.cu create mode 100644 test/unit/util/tensor_reduce.cu create mode 100644 tools/library/include/cutlass/library/arch_mappings.h create mode 100644 tools/library/src/reference/gemm.cu create mode 100644 tools/library/src/reference/gemm_reference_operation.h create mode 100644 tools/library/src/reference/initialize_reference_operations.cu create mode 100644 tools/profiler/src/sparse_gemm_operation_profiler.cu create mode 100644 tools/profiler/src/sparse_gemm_operation_profiler.h create mode 100644 tools/util/include/cutlass/util/host_uncompress.h create mode 100644 tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h create mode 100644 
tools/util/include/cutlass/util/reference/device/gemm_complex.h create mode 100644 tools/util/include/cutlass/util/reference/device/tensor_reduce.h create mode 100644 tools/util/include/cutlass/util/reference/host/tensor_reduce.h diff --git a/CHANGELOG.md b/CHANGELOG.md index 138161065b..96053eefb0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,20 @@ # CUTLASS 2.x +## [2.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.3.0) (2020-09-23) + * [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/) + * [Sparse Tensor Core GEMM kernels](test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu): + * Direct access to Sparse Tensor Cores and maximum performance via [`mma.sp.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends) + * Fast SGEMM targeting GeForce RTX 30-series CUDA Cores + * Minor Features: + * [Activation functions](/include/cutlass/epilogue/thread/activation.h) such as [GeLU](/include/cutlass/epilogue/thread/linear_combination_gelu.h) and [Sigmoid](/include/cutlass/epilogue/thread/linear_combination_sigmoid.h) + * Small [matrix](/include/cutlass/matrix.h) and [quaternion](/include/cutlass/quaternion.h) template classes in device code + * [Floating-point constants](/include/cutlass/constants.h) + * NVIDIA Ampere GPU Architecture examples and documentation: + * [Tensor Float 32](/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and + * [Sparse Tensor Cores](/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu) + * Documentation added on CUTLASS [efficient row-major epilogue](/media/docs/gemm_api.md#efficient-epilogue) + ## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08) * [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/) * Fast Tensor Core operations: diff --git a/CMakeLists.txt b/CMakeLists.txt index b6583747c6..d853a9dd3c 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,7 +32,7 @@ endif() message(STATUS "CMake Version: ${CMAKE_VERSION}") -project(CUTLASS VERSION 2.2.0 LANGUAGES CXX) +project(CUTLASS VERSION 2.3.0 LANGUAGES CXX) include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake) find_package(Doxygen QUIET) @@ -69,6 +69,8 @@ endif() set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable CUTLASS Examples") set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools") +set(CUTLASS_ENABLE_LIBRARY ${CUTLASS_ENABLE_TOOLS} CACHE BOOL "Enable CUTLASS Library") +set(CUTLASS_ENABLE_PROFILER ${CUTLASS_ENABLE_TOOLS} CACHE BOOL "Enable CUTLASS Profiler") if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME}) set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_TOOLS_INIT}) @@ -101,6 +103,9 @@ endif() if (NOT CUDA_VERSION VERSION_LESS 11.0) list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 80) endif() +if (NOT CUDA_VERSION VERSION_LESS 11.1) + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 86) +endif() set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.") set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.") @@ -164,12 +169,14 @@ set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code. # set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of operation name filters. 
Default '' means all operations are enabled.") set(CUTLASS_LIBRARY_KERNELS "" CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If 'all' is specified, all kernels are enabled.") +set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma delimited list of kernel names to exclude from build.") # Test Levels L0, L1, L2 set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.") set_property(CACHE CUTLASS_TEST_LEVEL PROPERTY STRINGS 0 1 2) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL}) +list(APPEND CUTLASS_CUDA_CLANG_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL}) # # CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations. @@ -181,6 +188,11 @@ else() set(CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT ON) endif() +# Trace levels for debugging +set(CUTLASS_DEBUG_TRACE_LEVEL "0" CACHE STRING "Level of debug tracing to perform.") +list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_DEBUG_TRACE_LEVEL=${CUTLASS_DEBUG_TRACE_LEVEL}) + + set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL "Enable PTX mma instruction for collective matrix multiply operations.") @@ -352,7 +364,7 @@ set_target_properties(CUTLASS PROPERTIES EXPORT_NAME cutlass) set(CUTLASS_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include CACHE PATH "CUTLASS Header Library") -set(CUTLASS_GENERATOR_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/library/) +set(CUTLASS_GENERATOR_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/library CACHE INTERNAL "Location of generator scripts") # The following utility directory is needed even if the tools build is disabled, so it exists here. set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/util/include CACHE INTERNAL "") diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index f8778b80e6..a4e0a2a435 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -16,6 +16,9 @@ Naila Farooqui Piotr Majcher Paul Springer Jin Wang +Aniket Shivam +Chinmay Talegaonkar +Shang Zhang Scott Yokim Markus Hohnerbach Aditya Atluri @@ -52,6 +55,8 @@ Olivier Giroux Stephen Jones Rishkul Kulkarni Bryce Lelbach +Matthew Nicely Joel McCormack Kyrylo Perelygin + diff --git a/CUDA.cmake b/CUDA.cmake index b8b343a723..c887178a89 100644 --- a/CUDA.cmake +++ b/CUDA.cmake @@ -213,7 +213,14 @@ function(cutlass_correct_source_file_language_property) endif() endfunction() -set(CUTLASS_UNITY_BUILD_ENABLED OFF CACHE BOOL "Enable combined source compilation") +# If building with all kernels, set UNITY build on by default. +if (CUTLASS_LIBRARY_KERNELS MATCHES "all") + set(CUTLASS_UNITY_BUILD_ENABLED_INIT ON) +else() + set(CUTLASS_UNITY_BUILD_ENABLED_INIT OFF) +endif() + +set(CUTLASS_UNITY_BUILD_ENABLED ${CUTLASS_UNITY_BUILD_ENABLED_INIT} CACHE BOOL "Enable combined source compilation") set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files") function(cutlass_unify_source_files TARGET_ARGS_VAR) diff --git a/README.md b/README.md index b0a91e77c6..88a1b40706 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ ![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 2.2 +# CUTLASS 2.3 -_CUTLASS 2.2 - June 2020_ +_CUTLASS 2.3 - September 2020_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA. @@ -30,6 +30,14 @@ See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly. 
See the [functionality listing](media/docs/functionality.md) for the list of operations supported at each level of the execution model hierarchy. +# What's New in CUTLASS 2.3 + +CUTLASS 2.3 is a minor update to CUTLASS adding: +- GEMMs targeting structured [Sparse Tensor Cores](test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu) in NVIDIA Ampere Architecture GPUs +- Fast SGEMM kernels targeting GeForce RTX 30-series CUDA Cores +- Intended to be compiled with [CUDA 11.1 Toolkit](https://developer.nvidia.com/cuda-toolkit) +- See the [CHANGELOG](CHANGELOG.md) for more details. + # What's New in CUTLASS 2.2 CUTLASS 2.2 is a significant update to CUTLASS adding: @@ -42,7 +50,7 @@ CUTLASS 2.2 is a significant update to CUTLASS adding: # What's New in CUTLASS 2.1 -CUTLASS 2.1 is a minor update to CUTLASS 2.0 adding: +CUTLASS 2.1 is a minor update to CUTLASS adding: - [Planar complex GEMM kernels](/examples/10_planar_complex/planar_complex.cu) targeting Volta and Turing Tensor Cores - BLAS-style API to launch kernels compiled into the [CUTLASS Library](/media/docs/quickstart.md#cutlass-library) @@ -71,8 +79,8 @@ using CUDA 11.0 Toolkit. Tensor Core operations are implemented using CUDA's # Compatibility CUTLASS requires a C++11 host compiler and -performs best when built with the [CUDA 11.0 Toolkit](https://developer.nvidia.com/cuda-toolkit). -It is compatible with CUDA 9.2, CUDA 10.0, CUDA 10.1, and CUDA 10.2. +performs best when built with the [CUDA 11.1 Toolkit](https://developer.nvidia.com/cuda-toolkit). +It is compatible with CUDA 9.2, CUDA 10.0, CUDA 10.1, CUDA 10.2, and CUDA 11.0. We have tested the following environments. @@ -99,10 +107,11 @@ any Maxwell-, Pascal-, Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GP |NVIDIA GeForce RTX 2080 TI, 2080, 2070|7.5|10.0|10.2| |NVIDIA Tesla T4|7.5|10.0|10.2| |NVIDIA A100|8.0|11.0|11.0| +|NVIDIA GeForce 3090|8.6|11.1|11.1| # Documentation -CUTLASS 2.2 is described in the following documents and the accompanying +CUTLASS is described in the following documents and the accompanying [Doxygen documentation](https://nvidia.github.io/cutlass). - [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS @@ -136,14 +145,14 @@ $ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc ``` Create a build directory within the CUTLASS project, then run CMake. By default CUTLASS will build kernels -for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, and 8.0. To reduce compile time you can specify +for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, 8.0, and 8.6. To reduce compile time you can specify the architectures to build CUTLASS for by changing the CMake configuration setting `CUTLASS_NVCC_ARCHS`. ``` $ mkdir build && cd build -$ cmake .. -DCUTLASS_NVCC_ARCHS=75 # compiles for NVIDIA's Turing GPU architecture +$ cmake .. -DCUTLASS_NVCC_ARCHS=80 # compiles for NVIDIA's Ampere Architecture ``` From the `build/` directory, compile and run the CUTLASS unit tests by building the target `test_unit` with make. @@ -258,15 +267,25 @@ The `tools/profiler/` directory contains a command-line utility for launching ea It can be built as follows: ``` -$ make cutlass_profiler -j +$ make cutlass_profiler -j16 ``` -To limit compilation time, only one tile size is instantiated for each data type, math instruction, and layout. +By default, only one tile size is instantiated for each data type, math instruction, and layout. To instantiate all, set the following environment variable when running CMake from an empty `build/` directory. 
+Beware, this results in *thousands* of kernels and long build times. ``` $ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=all ... -$ make cutlass_profiler -j +$ make cutlass_profiler -j16 +``` + +To compile strictly one kernel or a small set of kernels, a comma-delimited list of kernel names with +wildcard characters may be reduce the set of kernels. The following builds exactly one kernel: + +``` +$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=cutlass_simt_sgemm_128x128_8x2_nn_align1 +... +$ make cutlass_profiler -j16 ``` Example command line for profiling SGEMM kernels is as follows: diff --git a/examples/02_dump_reg_shmem/dump_reg_shmem.cu b/examples/02_dump_reg_shmem/dump_reg_shmem.cu index ed712aa84e..9d7db79a95 100644 --- a/examples/02_dump_reg_shmem/dump_reg_shmem.cu +++ b/examples/02_dump_reg_shmem/dump_reg_shmem.cu @@ -69,7 +69,7 @@ template __global__ void kernel_dump(typename GmemIterator::Params params, typename GmemIterator::TensorRef ref) { - __shared__ Element shared_storage[EXAMPLE_MATRIX_ROW * EXAMPLE_MATRIX_COL]; + extern __shared__ Element shared_storage[]; // Construct the global iterator and load the data to the fragments. int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; @@ -164,8 +164,11 @@ int main() { dim3 grid(1, 1); dim3 block(32, 1, 1); + int smem_size = + int(sizeof(Element) * EXAMPLE_MATRIX_ROW * EXAMPLE_MATRIX_COL); + kernel_dump - <<>>(params, matrix.device_ref()); + <<>>(params, matrix.device_ref()); cudaError_t result = cudaDeviceSynchronize(); diff --git a/examples/10_planar_complex/planar_complex.cu b/examples/10_planar_complex/planar_complex.cu index b7318b99c2..d810777d9c 100644 --- a/examples/10_planar_complex/planar_complex.cu +++ b/examples/10_planar_complex/planar_complex.cu @@ -50,7 +50,7 @@ To build strictly the planar complex kernels needed for general application, execute the following CMake command in an empty build directory. - $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \ + $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \ -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex This builds all planar complex GEMM variants for Volta and Turing architectures. @@ -59,7 +59,7 @@ specified as follows. This only builds planar complex GEMMs targeting Tensor Cores for the 'CN' layout configuration (conjugate A operand with both A and B as column-major). - $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \ + $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \ -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_f16*cn $ make 10_planar_complex @@ -526,6 +526,11 @@ int main(int argc, char const **args) { return 0; } } + else { + // NVIDIA Ampere Architecture GPUs (SM80 and later) are fully supported on CUDA 11 Toolkit and beyond. + // + // fall through + } // // Parse options diff --git a/examples/11_planar_complex_array/planar_complex_array.cu b/examples/11_planar_complex_array/planar_complex_array.cu index 6a0270533e..53134168a0 100644 --- a/examples/11_planar_complex_array/planar_complex_array.cu +++ b/examples/11_planar_complex_array/planar_complex_array.cu @@ -48,7 +48,7 @@ To build strictly the planar complex kernels needed for general application, execute the following CMake command in an empty build directory. - $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \ + $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \ -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_*gemm_planar_complex This builds all planar complex GEMM variants for Volta and Turing architectures. @@ -57,7 +57,7 @@ specified as follows. 
This only builds planar complex GEMMs targeting Tensor Cores for the 'CN' layout configuration (conjugate A operand with both A and B as column-major). - $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75" \ + $ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" \ -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_f16_s*gemm_planar_complex_array_f16*cn $ make 11_planar_complex_array @@ -586,6 +586,11 @@ int main(int argc, char const **args) { return 0; } } + else { + // NVIDIA Ampere Architecture GPUs (SM80 and later) are fully supported on CUDA 11 Toolkit and beyond. + // + // fall through + } // // Parse options diff --git a/examples/13_fused_two_gemms/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm80.h b/examples/13_fused_two_gemms/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm80.h new file mode 100644 index 0000000000..32b77128e8 --- /dev/null +++ b/examples/13_fused_two_gemms/b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm80.h @@ -0,0 +1,205 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "device/b2b_gemm.h" +#include "b2b_interleaved_gemm_run.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +void run_nonfused_gemm_s8_sm80() { + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576); + cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64); + ElementCompute alpha0 = ElementCompute(2); + ElementCompute beta0 = ElementCompute(0); + ElementCompute alpha1 = ElementCompute(2); + ElementCompute beta1 = ElementCompute(0); + + using ThreadblockShape0 = cutlass::gemm::GemmShape<128, 64, 64>; + using WarpShape0 = cutlass::gemm::GemmShape<64, 64, 64>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + using Gemm0 = cutlass::gemm::device::Gemm< + int8_t, + cutlass::layout::ColumnMajorInterleaved<32>, + int8_t, + cutlass::layout::RowMajorInterleaved<32>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<32>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape0, + WarpShape0, + InstructionShape, + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + 16, + 16, + false, + cutlass::arch::OpMultiplyAddSaturate, + true + >; + using Gemm1 = cutlass::gemm::device::Gemm< + int8_t, + cutlass::layout::ColumnMajorInterleaved<32>, + int8_t, + cutlass::layout::RowMajorInterleaved<32>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<32>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape1, + WarpShape1, + InstructionShape, + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + 16, + 16, + false, + cutlass::arch::OpMultiplyAddSaturate, + true + >; + + B2bInterleavedNonFusedGemmRun nonFusedGemm; + + std::cout << "Running Non-fused back-to-back INT8 NT interleaved GEMMs...\n"; + bool pass = nonFusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1); + if(pass) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; +} + +void run_fused_gemm_s8_sm80() { + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + cutlass::gemm::GemmCoord problem_size_0(128*1600, 64, 576); + cutlass::gemm::GemmCoord problem_size_1(128*1600, 128, 64); + ElementCompute alpha0 = ElementCompute(2); + ElementCompute beta0 = ElementCompute(0); + ElementCompute alpha1 = ElementCompute(2); + ElementCompute beta1 = ElementCompute(0); + + using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 
64, 64>; + using WarpShape0 = cutlass::gemm::GemmShape<32, 64, 64>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>; + using WarpShape1 = cutlass::gemm::GemmShape<32, 128, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + using EpilogueOutputOp0 = + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, + 8 * InstructionShape::kN / 32, + ElementAccumulator, + ElementCompute + >; + + using EpilogueOutputOp1 = + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >; + + + + using B2bGemm = cutlass::gemm::device::B2bGemm< + int8_t, + cutlass::layout::ColumnMajorInterleaved<32>, + int8_t, + cutlass::layout::RowMajorInterleaved<32>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<32>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + 16, + 16, + false, + cutlass::arch::OpMultiplyAddSaturate, + true + >; + + B2bInterleavedFusedGemmRun fusedGemm; + + std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs...\n"; + bool passed = fusedGemm.run(problem_size_0, problem_size_1, alpha0, beta0, alpha1, beta1); + if(passed) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/examples/13_fused_two_gemms/b2b_interleaved_gemm_run.h b/examples/13_fused_two_gemms/b2b_interleaved_gemm_run.h index 906cabb409..e98be9e511 100644 --- a/examples/13_fused_two_gemms/b2b_interleaved_gemm_run.h +++ b/examples/13_fused_two_gemms/b2b_interleaved_gemm_run.h @@ -38,6 +38,8 @@ #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/host_reorder.h" #include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_relu.h" + #include "helper.h" #define CHECK_GT(val1, val2) \ @@ -115,7 +117,9 @@ struct B2bInterleavedNonFusedGemmRun ElementCompute beta0 = ElementCompute(0), ElementCompute alpha1 = ElementCompute(1), ElementCompute beta1 = ElementCompute(0), - bool relu = true) { + bool relu = true, + int warm_ups = 1, + int runs = 100) { // // Allocate the GEMM workspace @@ -232,6 +236,13 @@ struct B2bInterleavedNonFusedGemmRun status = gemm_op_1.initialize(arguments_1); CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = gemm_op_0(); + CUTLASS_CHECK(status); + status = gemm_op_1(); + CUTLASS_CHECK(status); + } // // Run the GEMM // @@ -242,14 +253,14 @@ struct B2bInterleavedNonFusedGemmRun cudaEventRecord(start); - for(int i = 0; i < 100; i++) { + for(int i = 0; i < runs; i++) { status = gemm_op_0(); CUTLASS_CHECK(status); } cudaEventRecord(stop1); - for(int i = 0; i < 100; i++) { + for(int i = 0; i < runs; i++) { status = gemm_op_1(); CUTLASS_CHECK(status); @@ -261,9 +272,9 @@ struct B2bInterleavedNonFusedGemmRun cudaEventElapsedTime(&gemm0Time, start, stop1); cudaEventElapsedTime(&gemm1Time, stop1, stop2); cudaEventElapsedTime(&totalTime, start, stop2); - std::cout << "gemm 0 time " << gemm0Time / 100.0 << " ms\n"; - std::cout << "gemm 1 time " << gemm1Time / 100.0 << " ms\n"; - std::cout << "total time " << totalTime / 100.0 << " ms\n"; + std::cout << "gemm 0 time " << gemm0Time / (float)runs << " 
ms\n"; + std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n"; + std::cout << "total time " << totalTime / (float)runs << " ms\n"; tensor_D0.sync_host(); tensor_D1.sync_host(); @@ -302,7 +313,7 @@ struct B2bInterleavedNonFusedGemmRun reference_gemm_1( problem_size_1, alpha1, - tensor_D0.device_ref(), + reference_D0.device_ref(), tensor_B1.device_ref(), beta1, tensor_C1.device_ref(), @@ -420,7 +431,9 @@ struct B2bInterleavedFusedGemmRun ElementCompute beta0 = ElementCompute(0), ElementCompute alpha1 = ElementCompute(1), ElementCompute beta1 = ElementCompute(0), - bool relu = true) { + bool relu = true, + int warm_ups = 1, + int runs = 100) { // // Allocate the GEMM workspace @@ -478,7 +491,7 @@ struct B2bInterleavedFusedGemmRun CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); //Reorder B0 - cutlass::reorder_column( + cutlass::reorder_column<16>( tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0); cutlass::reorder_column( tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1); @@ -526,6 +539,11 @@ struct B2bInterleavedFusedGemmRun CUTLASS_CHECK(status); + for(int i = 0; i < warm_ups; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + // // Run the GEMM // @@ -536,7 +554,7 @@ struct B2bInterleavedFusedGemmRun cudaEventRecord(start); - for(int i = 0; i < 100; i++) { + for(int i = 0; i < runs; i++) { status = b2b_gemm_op(); CUTLASS_CHECK(status); @@ -546,7 +564,7 @@ struct B2bInterleavedFusedGemmRun cudaDeviceSynchronize(); float gemmTime; cudaEventElapsedTime(&gemmTime, start, stop); - std::cout << "time " << gemmTime / 100.0 << " ms\n"; + std::cout << "time " << gemmTime / (float)runs << " ms\n"; //tensor_D0.sync_host(); tensor_D1.sync_host(); diff --git a/examples/13_fused_two_gemms/fused_gemm.cu b/examples/13_fused_two_gemms/fused_gemm.cu index a7856abe5a..edc08d3189 100644 --- a/examples/13_fused_two_gemms/fused_gemm.cu +++ b/examples/13_fused_two_gemms/fused_gemm.cu @@ -30,7 +30,6 @@ two unfused GEMM operations, demonstrating a speedup of the fused kernel on the NVIDIA Turing GPU architecture. Problem size: - GEMM1 (M,N,K): 128*1600, 64, 576 GEMM2 (M,N,K): 128*1600, 128, 64 @@ -42,16 +41,17 @@ also requires warp_tile_N = thread_block_tile_N so the data required by each war register-file-resident. Performance: - - fp16 on Tesla T4 @ 1590MHz (non-fused vs. fused): 1.39011 ms vs. 1.26035 ms - int8 on Tesla T4 @ 1590MHz (non-fused vs. fused): 0.751759 ms vs. 0.62971 ms - fp16 on Quadro RTX 8000 @ 1890MHz (non-fused vs. fused): 0.721144 ms vs. 0.629864 ms - int8 on Quadro RTX 8000 @ 1890MHz (non-fused vs. fused): 0.379049 ms vs. 0.324764 ms + - int8 on GA100 @ 1200MHz (non-fused vs. fused): 0.153795 ms vs. 
0.129874 ms */ #include "b2b_gemm_f16t_f16n_f16t_tensor_op_f16_sm75.h" #include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm75.h" +#include "b2b_gemm_s8n_s8t_s8n_tensor_op_s32_sm80.h" int run() { @@ -71,7 +71,10 @@ int run() { return 0; } -#if defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + run_nonfused_gemm_s8_sm80(); + run_fused_gemm_s8_sm80(); +#elif defined(CUTLASS_ARCH_MMA_SM75_SUPPORTED) run_nonfused_gemm_f16(); run_fused_gemm_f16(); run_nonfused_gemm_s8(); diff --git a/examples/13_fused_two_gemms/kernel/b2b_gemm.h b/examples/13_fused_two_gemms/kernel/b2b_gemm.h index d106fa46af..5df5e4e38d 100644 --- a/examples/13_fused_two_gemms/kernel/b2b_gemm.h +++ b/examples/13_fused_two_gemms/kernel/b2b_gemm.h @@ -210,7 +210,8 @@ struct B2bGemm { // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || @@ -313,7 +314,8 @@ struct B2bGemm { // Masked tile iterators constructed from members // - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); //assume identity swizzle MatrixCoord threadblock_offset( diff --git a/examples/13_fused_two_gemms/kernel/default_b2b_gemm.h b/examples/13_fused_two_gemms/kernel/default_b2b_gemm.h index 45b2d545ef..dab9db904c 100644 --- a/examples/13_fused_two_gemms/kernel/default_b2b_gemm.h +++ b/examples/13_fused_two_gemms/kernel/default_b2b_gemm.h @@ -217,7 +217,85 @@ struct DefaultB2bGemm< }; -/// Partial specialization for Turing IMMA Interleaved layout +/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout +template < + /// Element type for A matrix operand + typename ElementA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Number of Interleaved k + int InterleavedK, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Is Beta zero or not + bool IsBetaZero> +struct DefaultB2bGemm< + ElementA, layout::ColumnMajorInterleaved, kAlignmentA, + ElementB, layout::RowMajorInterleaved, kAlignmentB, + ElementC, layout::ColumnMajorInterleaved, int32_t, + arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, EpilogueOutputOp0, 
EpilogueOutputOp1, + ThreadblockSwizzle, Stages, + SplitKSerial, Operator, IsBetaZero> { + using LayoutA = layout::ColumnMajorInterleaved; + using LayoutB = layout::RowMajorInterleaved; + using LayoutC = layout::ColumnMajorInterleaved; + + using ElementAccumulator = int32_t; + + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, Stages, Operator, EpilogueOutputOp0, + true>::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock:: + DefaultInterleavedEpilogueTensorOp< + ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, + 64 / sizeof_bits::value, InterleavedK, + IsBetaZero>::Epilogue; + + /// Define the kernel-level GEMM operator. + using B2bGemmKernel = kernel::B2bGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + + +/// Partial specialization for Turing Integer Tensor Core Interleaved layout template < /// Element type for A matrix operand typename ElementA, diff --git a/examples/13_fused_two_gemms/threadblock/b2b_mma_multistage.h b/examples/13_fused_two_gemms/threadblock/b2b_mma_multistage.h new file mode 100644 index 0000000000..8782b7af55 --- /dev/null +++ b/examples/13_fused_two_gemms/threadblock/b2b_mma_multistage.h @@ -0,0 +1,862 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
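+
+    This multistage variant of the fused back-to-back MMA stages operand tiles through
+    shared memory with cp.async (LDGSTS) and therefore requires the NVIDIA Ampere
+    Architecture (SM80) or newer.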
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA0, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB0, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile + // (concept::MmaTensorOpFragmentIterator) + typename FragmentIteratorA1_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) 
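+    /// (applied to the first GEMM's accumulators, which remain register-resident and
+    /// serve as the A operand of the second GEMM)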
+ typename OutputOp_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class B2bMmaMultistage : + public B2bMmaBase { +public: + ///< Base class + using Base = B2bMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape0 = Shape0_; + ///< Iterates over tiles of A operand in global memory + using IteratorA0 = IteratorA0_; + ///< Iterates over tiles of B operand in global memory + using IteratorB0 = IteratorB0_; + ///< Policy describing tuning details + using Policy0 = Policy0_; + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over intermediate accumulator tile + using FragmentIteratorA1 = FragmentIteratorA1_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + using SmemIteratorB1 = SmemIteratorB1_; + + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + ///< Epilogue after 1st Gemm + using OutputOp = OutputOp_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; + static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA0 = Operator0::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB0 = Operator0::kTransformB; + + /// Complex transform on B operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Internal structure exposed for introspection. 
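+  /// Counts the cp.async (LDGSTS) requests issued per stage and per warp-level MMA
+  /// group for operands A0, B0, and B1.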
+ struct Detail { + + static_assert(Base::kWarpGemmIterations0 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + static_assert(Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const TBLDGSTSIterationsA0 = + IteratorA0::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLDGSTSIterationsB0 = + IteratorB0::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLDGSTSIterationsB1 = + IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA0 = + (TBLDGSTSIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB0 = + (TBLDGSTSIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (TBLDGSTSIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; + }; + + private: + + using WarpLoadedFragmentA0 = typename Operator0::FragmentA; + using WarpLoadedFragmentB0 = typename Operator0::FragmentB; + /// Warp Fragment of operand A1 loaded from accmulator tile + using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; + using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A0_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::B2bMmaSharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A0_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.sharedStorage0.operand_B_ref(), thread_idx), + smem_iterator_B1_(shared_storage.sharedStorage1.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k = warp_idx / 
(Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations0 * warp_idx_k}); + this->warp_tile_iterator_B0_.add_tile_offset( + {Base::kWarpGemmIterations0 * warp_idx_k, warp_idx_n}); + this->warp_tile_iterator_B1_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_0(IteratorA0 &iterator_A0, IteratorB0 &iterator_B0, + int group_start_A0 = 0, int group_start_B0 = 0) { + iterator_A0.set_iteration_index(group_start_A0 * + IteratorA0::kAccessesPerVector); + this->smem_iterator_A0_.set_iteration_index(group_start_A0); + + // LDGSTS for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { + if (group_start_A0 + j < Detail::TBLDGSTSIterationsA0) { + typename IteratorA0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA0::ThreadMap::kElementsPerAccess / + IteratorA0::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A0.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A0.valid()); + + ++iterator_A0; + } + + ++this->smem_iterator_A0_; + } + } + + iterator_B0.set_iteration_index(group_start_B0 * + IteratorB0::kAccessesPerVector); + this->smem_iterator_B0_.set_iteration_index(group_start_B0); + + // LDGSTS for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { + if (group_start_B0 + j < Detail::TBLDGSTSIterationsB0) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / + IteratorB0::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B0.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B0.valid()); + + ++iterator_B0; + } + ++this->smem_iterator_B0_; + } + } + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1(IteratorB1 &iterator_B1, + int group_start_B1 = 0) { + iterator_B1.set_iteration_index(group_start_B1 * + IteratorB1::kAccessesPerVector); + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // LDGSTS for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::TBLDGSTSIterationsB1) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_0, + ///< destination accumulator tile + FragmentC1 &accum, + ///< iterator over A operand in global memory + IteratorA0 iterator_A0, + ///< iterator over B operand in 
global memory + IteratorB0 iterator_B0, + ///< iterator over B operand in global memory + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC0 const &src_accum, + ///< epilogue operation after 1st Gemm + OutputOp output_op_0) + { + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_0) { + + if (gemm_k_iterations_0 == 0) { + iterator_A0.clear_mask(); + iterator_B0.clear_mask(); + } + + iterator_A0.set_iteration_index(0); + this->smem_iterator_A0_.set_iteration_index(0); + + // LDGSTS for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLDGSTSIterationsA0; ++j) { + typename IteratorA0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A0_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA0::ThreadMap::kElementsPerAccess / + IteratorA0::kAccessesPerVector / 8; + + int src_bytes = (iterator_A0.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A0.get(), iterator_A0.valid()); + + ++iterator_A0; + } + + ++this->smem_iterator_A0_; + } + + iterator_B0.set_iteration_index(0); + this->smem_iterator_B0_.set_iteration_index(0); + + // LDGSTS for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLDGSTSIterationsB0; ++j) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / + IteratorB0::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B0.get(), iterator_B0.valid()); + + ++iterator_B0; + } + + ++this->smem_iterator_B0_; + } + + // Move to the next stage + iterator_A0.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + + this->smem_iterator_A0_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. 
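// (In outline, every prologue stage follows the same cp.async pattern: the LDGSTS
//  loops above issue the global->shared copies for one tile, and the fence below
//  commits them as a single group so a later wait can count whole stages rather
//  than individual copies. A minimal sketch of the pairing, using the CUTLASS
//  arch helpers already referenced in this file:
//
//    // issue LDGSTS for one stage (the A0/B0 loops above)
//    cutlass::arch::cp_async_fence();                     // close this stage's group
//    ...
//    cutlass::arch::cp_async_wait<Base::kStages - 2>();   // allow at most kStages-2
//    __syncthreads();                                     // groups to stay in flight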
+ cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA0 warp_loaded_frag_A0[2]; + WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; + WarpTransformedFragmentA0 warp_transformed_frag_A0[2]; + WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; + + Operator0 warp_mma0; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (gemm_k_iterations_0 == 0) { + iterator_A0.clear_mask(); + iterator_B0.clear_mask(); + } + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0], + warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]); + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k > 0) + warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + warp_loaded_frag_A0[warp_mma_k % 2], + warp_loaded_frag_B0[warp_mma_k % 2]); + + warp_mma0( + accum0, + warp_transformed_frag_A0[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + accum0 + ); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations0 - 1) { + int group_start_iteration_A0, group_start_iteration_B0; + + group_start_iteration_A0 = warp_mma_k * Detail::kAccessesPerGroupA0; + group_start_iteration_B0 = warp_mma_k * Detail::kAccessesPerGroupB0; + + copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, + group_start_iteration_B0); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations0) { + int group_start_iteration_A0, group_start_iteration_B0; + group_start_iteration_A0 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA0; + group_start_iteration_B0 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB0; + + copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, + group_start_iteration_B0); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
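// (The wait below is parameterized on the number of cp.async groups allowed to
//  remain in flight -- cp_async_wait<Base::kStages - 2>() in this kernel -- so by
//  the time the warps read the next tile, the oldest outstanding stage has landed
//  in shared memory. The bookkeeping that follows advances the global and shared
//  iterators by one stage and wraps smem_write_stage_idx / smem_read_stage_idx
//  around the circular buffer of kStages shared-memory tiles.)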
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A0.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + + this->smem_iterator_A0_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations_0; + if (gemm_k_iterations_0 == 0) { + iterator_A0.clear_mask(); + iterator_B0.clear_mask(); + } + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations0) + warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2], + warp_transformed_frag_B0[(warp_mma_k + 1) % 2], + warp_loaded_frag_A0[(warp_mma_k + 1) % 2], + warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + } + + } + + + // 2nd Gemm + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile + FragmentIteratorA1 warp_tile_iterator_A1_(accum0); + + // + // Prologue + // + int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_1) { + + if (gemm_k_iterations_1 == 0) { +// iterator_A1.clear_mask(); + iterator_B1.clear_mask(); + } + +#if 0 + iterator_A1.set_iteration_index(0); + this->smem_iterator_A1_.set_iteration_index(0); + + // LDGSTS for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLDGSTSIterationsA1; ++j) { + typename IteratorA1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA1::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA1::ThreadMap::kElementsPerAccess / + IteratorA1::kAccessesPerVector / 8; + + int src_bytes = (iterator_A0.valid() ? 
kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A0.get(), iterator_A0.valid()); + + ++iterator_A0; + } + + ++this->smem_iterator_A0_; + } +#endif + + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // LDGSTS for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLDGSTSIterationsB1; ++j) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++this->smem_iterator_B1_; + } + + // Move to the next stage + //iterator_A1.add_tile_offset({0, 1}); + iterator_B1.add_tile_offset({1, 0}); + + //this->smem_iterator_A1_.add_tile_offset({0, 1}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand +// FragmentC0 accum0 = src_accum; + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + +// this->warp_tile_iterator_A1_.set_kgroup_index(0); + this->warp_tile_iterator_B1_.set_kgroup_index(0); + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0], output_op_0); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); + + ++warp_tile_iterator_A1_; + ++this->warp_tile_iterator_B1_; + + if (gemm_k_iterations_1 == 0) { +// iterator_A1.clear_mask(); + iterator_B1.clear_mask(); + } + + smem_write_stage_idx = Base::kStages - 1; + smem_read_stage_idx = 0; + + warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], + warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]); + + // + // Mainloop + // + + CUTLASS_PRAGMA_UNROLL + for (gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1 - (Base::kStages - 1); + gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
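// (Unlike the first GEMM, operand A1 never passes through global or shared
//  memory here: warp_tile_iterator_A1_ walks the accumulator fragment accum0 and
//  applies output_op_0, the first GEMM's epilogue, as each warp tile is loaded;
//  only B1 is staged through shared memory. The [2]-deep fragment arrays
//  double-buffer the warp tiles, roughly:
//
//    load frag[(warp_mma_k + 1) % 2]   // prefetch the next warp tile
//    mma  on frag[warp_mma_k % 2]      // math on the tile loaded last iteration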
+ +// this->warp_tile_iterator_A1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], output_op_0); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + + ++warp_tile_iterator_A1_; + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k > 0) + warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_B1[warp_mma_k % 2]); + + warp_mma1( + accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum + ); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { + int group_start_iteration_B1; + + group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; + + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + int group_start_iteration_B1; + group_start_iteration_B1 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + +// --gemm_k_iterations_1; + if (gemm_k_iterations_1 == 1) { + iterator_B1.clear_mask(); + } + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + + } + + + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_fused_two_gemms/threadblock/b2b_mma_pipelined.h b/examples/13_fused_two_gemms/threadblock/b2b_mma_pipelined.h index ca89cf0bdc..9887932a37 100644 --- a/examples/13_fused_two_gemms/threadblock/b2b_mma_pipelined.h +++ b/examples/13_fused_two_gemms/threadblock/b2b_mma_pipelined.h @@ -48,10 +48,6 @@ namespace gemm { namespace threadblock { //////////////////////////////////////////////////////////////////////////////////////////////// -template -struct chk_val { - static_assert(a==0, "check value"); -}; /// Structure to compute the matrix product targeting CUDA cores and SIMT math 
instructions. template < diff --git a/examples/13_fused_two_gemms/threadblock/default_b2b_mma.h b/examples/13_fused_two_gemms/threadblock/default_b2b_mma.h index cd1403c792..b3621f56e6 100644 --- a/examples/13_fused_two_gemms/threadblock/default_b2b_mma.h +++ b/examples/13_fused_two_gemms/threadblock/default_b2b_mma.h @@ -40,6 +40,7 @@ #include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" #include "threadblock/b2b_mma_pipelined.h" +#include "threadblock/b2b_mma_multistage.h" //////////////////////////////////////////////////////////////////////////////// @@ -200,8 +201,6 @@ template < typename ElementAccumulator, /// Tag indicating architecture to tune for typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, /// Threadblock-level tile size (concept: GemmShape) typename ThreadblockShape0, /// Threadblock-level tile size (concept: GemmShape) @@ -220,7 +219,7 @@ template < int InterleavedK> struct DefaultB2bMma, OperatorClass, ArchTag, + layout::ColumnMajorInterleaved, OperatorClass, arch::Sm75, ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp, true> { // Define the MmaCore components @@ -251,7 +250,7 @@ struct DefaultB2bMma, ElementB, LayoutB, 0, typename MmaCore0::IteratorThreadMapB>; - // Use fragment iterator for A operand + // Use fragment iterator for A1 operand using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true using FragmentIteratorA1 = cutlass::gemm::warp::MmaTensorOpFragmentIterator< @@ -282,6 +281,111 @@ struct DefaultB2bMma +struct DefaultB2bMma, OperatorClass, ArchTag, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, Stages, Operator, EpilogueOutputOp, true> { + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, Stages, + Operator, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, Stages, + Operator, true>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA0 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB0 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB>; + + // Use fragment iterator for A1 operand + using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, + InstructionShape, EpilogueOutputOp, true /*only handle beta=0 for 1st Gemm epilogue*/>; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using 
IteratorB1 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB>; + + + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistage< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + MmaCore0::kCacheOpA, + IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, + typename MmaCore1::Shape, FragmentIteratorA1, + IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, + ElementAccumulator, layout::ColumnMajorInterleaved, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>; +}; + +//////////////////////////////////////////////////////////////////////////////// + + } // namespace threadblock } // namespace gemm } // namespace cutlass diff --git a/examples/14_ampere_tf32_tensorop_gemm/CMakeLists.txt b/examples/14_ampere_tf32_tensorop_gemm/CMakeLists.txt new file mode 100644 index 0000000000..49e1a4f9e3 --- /dev/null +++ b/examples/14_ampere_tf32_tensorop_gemm/CMakeLists.txt @@ -0,0 +1,27 @@ +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, are permitted +# provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright notice, this list of +# conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, this list of +# conditions and the following disclaimer in the documentation and/or other materials +# provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific prior written +# permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 14_ampere_tf32_tensorop_gemm + ampere_tf32_tensorop_gemm.cu + ) + diff --git a/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu b/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu new file mode 100644 index 0000000000..2533557134 --- /dev/null +++ b/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu @@ -0,0 +1,278 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** +Please check example 07 and 08 for the basics of tensor op gemm kernels. On NVIDIA Ampere +architecture, most concept still holds. The two main differences are + +1. NVIDIA Ampere architecture introduces a new series of tensor core instructions (see + include/cutlass/arch/mma_sm80.h) which are more efficient on Ampere. + +2. NVIDIA Ampere architecture uses cp_async() to build multistage software pipeline to better hide + latency (see include/cutlass/gemm/threadblock/mma_multistage.h) + +Moreover, NVIDIA Ampere architecture starts supporting tfloat32 (see include/cutlass/tfloat32.h) +data types in tensor cores. One big advantage is that we can load in fp32 data and convert them +implicitly to tf32 inside the GEMM kernel which means no change is needed to accelerate traditional +fp32 data by using NVIDIA Ampere architecture. +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" +#include "helper.h" + +// The code section below describes datatype for input, output matrices and computation between +// elements in input matrices. +using ElementAccumulator = float; // <- data type of accumulator +using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations +using ElementInputA = float; // <- data type of elements in input matrix A +using ElementInputB = float; // <- data type of elements in input matrix B +using ElementOutput = float; // <- data type of elements in output matrix D + +// The code section below describes matrix layout of input and output matrices. 
Column Major for +// Matrix A, Row Major for Matrix B and Row Major for Matrix C +using LayoutInputA = cutlass::layout::RowMajor; +using LayoutInputB = cutlass::layout::ColumnMajor; +using LayoutOutput = cutlass::layout::RowMajor; + +// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM +using MMAOp = cutlass::arch::OpClassTensorOp; + +// This code section describes CUDA SM architecture number +using SmArch = cutlass::arch::Sm80; + +// This code section describes the tile size a thread block will compute +using ShapeMMAThreadBlock = + cutlass::gemm::GemmShape<128, 128, 16>; // <- threadblock tile M = 128, N = 128, K = 16 +// This code section describes tile size a warp will compute +using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = 64, N = 64, K = 16 +// This code section describes the size of MMA op +using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 + +// This code section describes how threadblocks are scheduled on GPU +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? + +// This code section describes the epilogue part of the kernel +using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function + +// Number of pipelines you want to use +constexpr int NumStages = 4; + +using Gemm = cutlass::gemm::device::Gemm; + +int run() { + + // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available + // in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11 Toolkit to run these examples. + if (!(__CUDACC_VER_MAJOR__ >= 11)) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; + return -1; + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (!((props.major * 10 + props.minor) >= 80)) { + std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 80." + << std::endl; + + // Return 0 so tests are considered passing if run on unsupported platforms. 
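  // (The guard above accepts any device of compute capability 8.0 or higher,
  //  i.e. the NVIDIA Ampere architecture targeted by SmArch = cutlass::arch::Sm80.)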
+ return 0; + } + + const int length_m = 5120; + const int length_n = 4096; + const int length_k = 4096; + + // Create a tuple of problem size for matrix multiplication + cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); + + // Initialize tensors using CUTLASS helper functions + cutlass::HostTensor tensor_a( + problem_size.mk()); // <- Create matrix A with dimensions M x K + cutlass::HostTensor tensor_b( + problem_size.kn()); // <- Create matrix B with dimensions K x N + cutlass::HostTensor tensor_c( + problem_size.mn()); // <- Create matrix C with dimensions M x N + cutlass::HostTensor tensor_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // CUTLASS kernel + cutlass::HostTensor tensor_ref_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // reference kernel + + // Fill input and output matrices on host using CUTLASS helper functions + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), + 1, + ElementInputA(4), + ElementInputA(-4), + 0); // <- Fill matrix A on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_b.host_view(), + 1, + ElementInputB(4), + ElementInputB(-4), + 0); // <- Fill matrix B on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(4), + ElementOutput(-4), + 0); // <- Fill matrix C on host with uniform-distribution random data + cutlass::reference::host::TensorFill( + tensor_d.host_view()); // <- fill matrix D on host with zeros + cutlass::reference::host::TensorFill( + tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros + + // Copy data from host to GPU + tensor_a.sync_device(); + tensor_b.sync_device(); + tensor_c.sync_device(); + tensor_d.sync_device(); + tensor_ref_d.sync_device(); + + // Initialize alpha and beta for dot product computation + ElementComputeEpilogue alpha = ElementComputeEpilogue(1); + ElementComputeEpilogue beta = ElementComputeEpilogue(0); + + // Split K dimension into 1 partitions + int split_k_slices = 1; + + // Create a tuple of gemm kernel arguments. 
This is later passed as arguments to launch + // instantiated CUTLASS kernel + typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication + tensor_a.device_ref(), // <- reference to matrix A on device + tensor_b.device_ref(), // <- reference to matrix B on device + tensor_c.device_ref(), // <- reference to matrix C on device + tensor_d.device_ref(), // <- reference to matrix D on device + {alpha, beta}, // <- tuple of alpha and beta + split_k_slices}; // <- k-dimension split factor + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm_op; + + // Initialize CUTLASS kernel with arguments and workspace pointer + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(status); + + // Launch initialized CUTLASS kernel + status = gemm_op(); + CUTLASS_CHECK(status); + + // Create instantiation for device reference gemm kernel + cutlass::reference::device::Gemm + gemm_device; + + // Launch device reference gemm kernel + gemm_device(problem_size, + alpha, + tensor_a.device_ref(), + tensor_b.device_ref(), + beta, + tensor_c.device_ref(), + tensor_ref_d.device_ref()); + + // Wait for kernels to finish + cudaDeviceSynchronize(); + + // Copy output data from CUTLASS and reference kernel to host for comparison + tensor_d.sync_host(); + tensor_ref_d.sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::host::TensorEquals( + tensor_d.host_view(), + tensor_ref_d.host_view()); + + std::cout << (passed ? "Passed" : "Failed") << std::endl; + + return (passed ? 0 : -1); +} + +int main() { + // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available + // in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. + if (!(__CUDACC_VER_MAJOR__ >= 11)) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; + + // Returning zero so this test passes when built on older Toolkits. + return 0; + } + else { + return run(); + } +} diff --git a/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt b/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt new file mode 100644 index 0000000000..2d0929c3a8 --- /dev/null +++ b/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt @@ -0,0 +1,27 @@ +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, are permitted +# provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright notice, this list of +# conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, this list of +# conditions and the following disclaimer in the documentation and/or other materials +# provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific prior written +# permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 15_ampere_sparse_tensorop_gemm + ampere_sparse_tensorop_gemm.cu + ) + diff --git a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu new file mode 100644 index 0000000000..02f65b199e --- /dev/null +++ b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu @@ -0,0 +1,311 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** +Please check example 07, 08 and 17 for the basics of dense tensor op gemm kernels. NVIDIA Ampere +architecture also supports structured sparse tensor op for tf32, fp16, int8 and int4. + +Sparse GEMM kernels needs to takes an additional E matrix which stores the meta data. The format of +meta data is different for every data types. CUTLASS templates can automatically infer it based on +input A and B. Check code below. + +Moreover, matrix E needs to be preprocessed so that it can use ldmatrix to load into the registers +efficiently. 
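In outline, the sizing used below (a sketch only; kSparse and kElementsPerElementE
are taken from the Gemm template defined in this file):

  A (compressed) : M x (K / kSparse)
  E (metadata)   : M x (K / kSparse / kElementsPerElementE)

The canonical-order E is what the host-side uncompress() and reference GEMM
consume; reorder_meta() produces the ldmatrix-friendly ordering that is actually
copied to the device.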
+*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" +#include "helper.h" + +// The code section below describes datatype for input, output matrices and computation between +// elements in input matrices. +using ElementAccumulator = int32_t; // <- data type of accumulator +using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations +using ElementInputA = cutlass::int4b_t; // <- data type of elements in input matrix A +using ElementInputB = cutlass::int4b_t; // <- data type of elements in input matrix B +using ElementOutput = int32_t; // <- data type of elements in output matrix D + +// The code section below describes matrix layout of input and output matrices. Column Major for +// Matrix A, Row Major for Matrix B and Row Major for Matrix C +using LayoutInputA = cutlass::layout::RowMajor; +using LayoutInputB = cutlass::layout::ColumnMajor; +using LayoutOutput = cutlass::layout::RowMajor; + +// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM +using MMAOp = cutlass::arch::OpClassTensorOp; + +// This code section describes CUDA SM architecture number +using SmArch = cutlass::arch::Sm80; + +// This code section describes the tile size a thread block will compute +using ShapeMMAThreadBlock = + cutlass::gemm::GemmShape<256, 128, 256>; // <- threadblock tile M = 128, N = 128, K = 256 +// This code section describes tile size a warp will compute +using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 256>; // <- warp tile M = 64, N = 64, K = 256 +// This code section describes the size of MMA op +using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 128>; // <- MMA Op tile M = 16, N = 8, K = 128 + +// This code section describes how threadblocks are scheduled on GPU +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? + +// This code section describes the epilogue part of the kernel +using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function + +// Number of pipelines you want to use +constexpr int NumStages = 3; + +using Gemm = cutlass::gemm::device::SparseGemm; + +// Data type and layout of meta data matrix E can be inferred from template Gemm. 
+using ElementInputE = typename Gemm::ElementE; +using LayoutInputE = typename Gemm::LayoutE; + +// Blow property is defined in include/cutlass/arch/sp_mma_sm80.h +// 50% Sparsity on Ampere +constexpr int kSparse = Gemm::kSparse; +// How many elements of A are covered per ElementE +constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; +// The size of individual meta data +constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; + +int run() { + + // Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available + // in CUDA 11.1. + // + // CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples. + if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl; + return -1; + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (!((props.major * 10 + props.minor) >= 80)) { + std::cerr << "Turing Tensor Core operations must be run on a machine with compute capability at least 80." + << std::endl; + + // Return 0 so tests are considered passing if run on unsupported platforms. + return 0; + } + + const int length_m = 512; + const int length_n = 512; + const int length_k = 1024; + + // Create a tuple of problem size for matrix multiplication + cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); + + // Initialize tensors using CUTLASS helper functions + cutlass::HostTensor tensor_a( + cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); // <- Create matrix A with dimensions M x (K / 2) + cutlass::HostTensor tensor_a_uncompressed( + problem_size.mk()); // <- Create uncompressed matrix A with dimensions M x K for reference computing + + cutlass::HostTensor tensor_b( + problem_size.kn()); // <- Create matrix B with dimensions K x N + cutlass::HostTensor tensor_c( + problem_size.mn()); // <- Create matrix C with dimensions M x N + cutlass::HostTensor tensor_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // CUTLASS kernel + cutlass::HostTensor tensor_ref_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // reference kernel + + // Create matrix E with dimensions M x (K / 2 / kElementsPerElementE). This one is used by reference computing. + cutlass::HostTensor tensor_e( + cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + // Same size as the above. The above one needs to be reordered and stored in this one. 
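  // (Two copies of E are kept on purpose: tensor_e holds the canonical row-major
  //  metadata consumed by the host-side uncompress() and reference GEMM, while
  //  tensor_e_reordered, declared next, receives the ldmatrix-friendly ordering
  //  produced by reorder_meta() and is the copy synchronized to the device.)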
+ cutlass::HostTensor tensor_e_reordered( + cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + + // Fill input and output matrices on host using CUTLASS helper functions + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), + 1, + ElementInputA(1), + ElementInputA(-1), + 0); // <- Fill matrix A on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_b.host_view(), + 1, + ElementInputB(1), + ElementInputB(-1), + 0); // <- Fill matrix B on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(1), + ElementOutput(-1), + 0); // <- Fill matrix C on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomSparseMeta( + tensor_e.host_view(), + 1, + kMetaSizeInBits); // <- Fill matrix E on host with uniform-distribution random meta data + cutlass::reference::host::TensorFill( + tensor_d.host_view()); // <- fill matrix D on host with zeros + cutlass::reference::host::TensorFill( + tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros + + // Reorder the meta data matrix so that we can use ldmatrix to load them to tensor core + // instructions. + cutlass::reorder_meta(tensor_e_reordered.host_ref(), tensor_e.host_ref(), + {problem_size.m(), problem_size.n(), + problem_size.k() / kSparse / kElementsPerElementE}); + + // Copy data from host to GPU + tensor_a.sync_device(); + tensor_b.sync_device(); + tensor_c.sync_device(); + tensor_d.sync_device(); + tensor_e_reordered.sync_device(); + tensor_ref_d.sync_device(); + + // Initialize alpha and beta for dot product computation + ElementComputeEpilogue alpha = ElementComputeEpilogue(1); + ElementComputeEpilogue beta = ElementComputeEpilogue(0); + + // Split K dimension into 1 partitions + int split_k_slices = 1; + + // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch + // instantiated CUTLASS kernel + typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication + tensor_a.device_ref(), // <- reference to matrix A on device + tensor_b.device_ref(), // <- reference to matrix B on device + tensor_c.device_ref(), // <- reference to matrix C on device + tensor_d.device_ref(), // <- reference to matrix D on device + tensor_e.device_ref(), // <- reference to matrix E on device + {alpha, beta}, // <- tuple of alpha and beta + split_k_slices}; // <- k-dimension split factor + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm_op; + + // Initialize CUTLASS kernel with arguments and workspace pointer + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(status); + + // Launch initialized CUTLASS kernel + status = gemm_op(); + CUTLASS_CHECK(status); + + // uncompress tensor_a based on meta data tensor_e. We need it for reference computing. 
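  // (Conceptually, for this 50%-sparse format each value kept in the compressed
  //  M x (K / kSparse) tile of A carries a small index in E -- kMetaSizeInBits
  //  bits per element -- giving its position along the original dense K
  //  dimension; uncompress() scatters the kept values back into a plain M x K
  //  tensor so the host reference GEMM can run on dense operands.)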
+ cutlass::uncompress(tensor_a_uncompressed.host_ref(), tensor_a.host_ref(), + tensor_e.host_ref(), problem_size.m(), problem_size.k()); + + // Create instantiation for host reference gemm kernel + cutlass::reference::host::Gemm + gemm_host; + + // Launch host reference gemm kernel + gemm_host(problem_size, + alpha, + tensor_a_uncompressed.host_ref(), + tensor_b.host_ref(), + beta, + tensor_c.host_ref(), + tensor_ref_d.host_ref()); + + // Copy output data from CUTLASS host for comparison + tensor_d.sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::host::TensorEquals( + tensor_d.host_view(), + tensor_ref_d.host_view()); + + std::cout << (passed ? "Passed" : "Failed") << std::endl; + + return (passed ? 0 : -1); +} + +int main() { + // Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available + // in CUDA 11.1. + // + // CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples. + if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl; + + // Returning zero so this test passes when built on older Toolkits. + return 0; + } + else { + return run(); + } +} diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 99379fe45a..aabfa53c62 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -71,6 +71,8 @@ foreach(EXAMPLE 11_planar_complex_array 12_gemm_bias_relu 13_fused_two_gemms + 14_ampere_tf32_tensorop_gemm + 15_ampere_sparse_tensorop_gemm ) add_subdirectory(${EXAMPLE}) diff --git a/include/cutlass/arch/arch.h b/include/cutlass/arch/arch.h index faf01cc656..eb0a2ad43b 100644 --- a/include/cutlass/arch/arch.h +++ b/include/cutlass/arch/arch.h @@ -55,6 +55,17 @@ struct Sm75 { struct Sm80 { static int const kMinComputeCapability = 80; }; +struct Sm86 { + static int const kMinComputeCapability = 86; +}; + +/// Triggers a breakpoint on the device +CUTLASS_DEVICE +void device_breakpoint() { +#if defined(__CUDA_ARCH__) + asm volatile (" brkpt;\n"); +#endif +} //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/memory.h b/include/cutlass/arch/memory.h index 48ef02cd0e..d9f386eec7 100644 --- a/include/cutlass/arch/memory.h +++ b/include/cutlass/arch/memory.h @@ -51,6 +51,8 @@ struct global_load; ///////////////////////////////////////////////////////////////////////////////////////////////// +// The redundant mov PTX instruction is used to enforce the compiler to +// initialize data to zero before ld.global template struct global_load struct global_load, 1, ElementA, LayoutA, ElementB, LayoutB, El ///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specifies internal data type for computation +struct SPFormatType { + enum Kind { + Thread + }; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation +template < + /// Size of the matrix product (concept: GemmShape) + typename Shape_, + /// Number of threads participating + int kThreads_, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + 
typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Inner product operator + typename Operator, + /// Specifies meta data format + SPFormatType::Kind SPFormat = SPFormatType::Thread +> +struct SparseMma; + } // namespace arch } // namespace cutlass @@ -165,4 +201,5 @@ struct Mma, 1, ElementA, LayoutA, ElementB, LayoutB, El #include "cutlass/arch/mma_sm70.h" #include "cutlass/arch/mma_sm75.h" #include "cutlass/arch/mma_sm80.h" +#include "cutlass/arch/sp_mma_sm80.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/mma_sm50.h b/include/cutlass/arch/mma_sm50.h index fce521dcee..cc4a94b17e 100644 --- a/include/cutlass/arch/mma_sm50.h +++ b/include/cutlass/arch/mma_sm50.h @@ -53,6 +53,7 @@ template < struct Mma, 1, float, LayoutA, float, LayoutB, float, LayoutC, OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAdd; CUTLASS_HOST_DEVICE void operator()( @@ -79,6 +80,7 @@ template < struct Mma, 1, double, LayoutA, double, LayoutB, double, LayoutC, OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAdd; CUTLASS_HOST_DEVICE void operator()( @@ -106,6 +108,7 @@ template < struct Mma, 1, int, LayoutA, int, LayoutB, int, LayoutC, OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAdd; CUTLASS_HOST_DEVICE void operator()( @@ -142,6 +145,7 @@ struct Mma< OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAddComplex; CUTLASS_HOST_DEVICE void operator()( @@ -181,6 +185,7 @@ struct Mma< OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAddComplex; CUTLASS_HOST_DEVICE void operator()( @@ -218,6 +223,7 @@ struct Mma< OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAddComplex; CUTLASS_HOST_DEVICE void operator()( @@ -255,6 +261,7 @@ struct Mma< OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAddComplex; CUTLASS_HOST_DEVICE void operator()( @@ -292,6 +299,7 @@ struct Mma< OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAddComplex; CUTLASS_HOST_DEVICE void operator()( @@ -327,6 +335,7 @@ struct Mma< OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; + using Operator = OpMultiplyAddComplex; CUTLASS_HOST_DEVICE void operator()( @@ -355,7 +364,8 @@ template < struct Mma, 1, half_t, LayoutA, half_t, LayoutB, float, LayoutC, OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 1>; - + using Operator = OpMultiplyAdd; + CUTLASS_HOST_DEVICE void operator()( Array &d, diff --git a/include/cutlass/arch/mma_sm60.h b/include/cutlass/arch/mma_sm60.h index ab0481ae44..5c82f74ec3 100644 --- a/include/cutlass/arch/mma_sm60.h +++ b/include/cutlass/arch/mma_sm60.h @@ -55,6 +55,7 @@ struct Mma< OpMultiplyAdd> { using Shape = gemm::GemmShape<2, 1, 1>; + using Operator = OpMultiplyAdd; CUTLASS_HOST_DEVICE void operator()( @@ -99,6 +100,7 @@ struct Mma< OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 2, 1>; + using Operator = OpMultiplyAdd; CUTLASS_HOST_DEVICE void operator()( @@ -143,6 +145,7 @@ struct Mma < OpMultiplyAdd> { using Shape = gemm::GemmShape<2, 2, 1>; + using Operator = OpMultiplyAdd; CUTLASS_HOST_DEVICE void operator()( @@ -196,7 +199,8 @@ struct Mma< OpMultiplyAdd> { using Shape = 
gemm::GemmShape<2, 2, 1>; - + using Operator = OpMultiplyAdd; + CUTLASS_HOST_DEVICE void operator()( Array &d, diff --git a/include/cutlass/arch/mma_sm61.h b/include/cutlass/arch/mma_sm61.h index 9ec8857e8c..6cbe260633 100644 --- a/include/cutlass/arch/mma_sm61.h +++ b/include/cutlass/arch/mma_sm61.h @@ -51,7 +51,8 @@ struct Mma< OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 4>; - + using Operator = OpMultiplyAdd; + CUTLASS_HOST_DEVICE void operator()( Array &d, @@ -98,6 +99,7 @@ struct Mma< OpMultiplyAdd> { using Shape = gemm::GemmShape<1, 1, 2>; + using Operator = OpMultiplyAdd; CUTLASS_HOST_DEVICE void operator()( diff --git a/include/cutlass/arch/mma_sm80.h b/include/cutlass/arch/mma_sm80.h index d75aa1336c..289c205cad 100644 --- a/include/cutlass/arch/mma_sm80.h +++ b/include/cutlass/arch/mma_sm80.h @@ -723,7 +723,6 @@ struct Mma< } }; - //////////////////////////////////////////////////////////////////////////////// // // Matrix Multiply 16816 - S8 input, S32 accumulation - SATURATE diff --git a/include/cutlass/arch/simd.h b/include/cutlass/arch/simd.h index 4520acc9b2..2503094ad3 100644 --- a/include/cutlass/arch/simd.h +++ b/include/cutlass/arch/simd.h @@ -85,7 +85,7 @@ Array mac(Array const &a, Array const &b, Array const &c Array d; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N; ++i) { - d[i] = a[i] * b[i] + c; + d[i] = a[i] * b[i] + c[i]; } return d; } diff --git a/include/cutlass/arch/sp_mma_sm80.h b/include/cutlass/arch/sp_mma_sm80.h new file mode 100644 index 0000000000..0c8989b86a --- /dev/null +++ b/include/cutlass/arch/sp_mma_sm80.h @@ -0,0 +1,1591 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Sparse matrix multiply accumulate for SM80 +*/ + +#pragma once + +#include "mma_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1)) + +#define CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED 1 + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +#define CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED +#endif +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 16832 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F16 = F16 * F16 + F16 +template <> +struct SparseMma< + gemm::GemmShape<16, 8, 32>, + 32, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + half_t, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread +> { + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = half_t; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 2; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c, uint32_t const &E, int const id2) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + uint32_t const *C = reinterpret_cast(&c); + uint32_t *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " + "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); + } + else if (id2 == 1) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " + "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x1;\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); + } + else { + assert(0); + } +#else + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = F16 * F16 + F32 +template <> +struct SparseMma< + gemm::GemmShape<16, 8, 32>, + 32, + half_t, + layout::RowMajor, + half_t, + layout::ColumnMajor, + float, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread + > { + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = half_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = half_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using 
Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 2; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c, uint32_t const &E, int const id2) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), + "r"(E)); + } + else if (id2 == 1) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), + "r"(E)); + } + else { + assert(0); + } + +#else + + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 16832 - Float BF16, FP32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = bf16 * bf16 + F32 +template <> +struct SparseMma, 32, bfloat16_t, layout::RowMajor, + bfloat16_t, layout::ColumnMajor, float, layout::RowMajor, + OpMultiplyAdd, SPFormatType::Thread> { + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = bfloat16_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = bfloat16_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 2; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c, uint32_t const &E, int const id2) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else if (id2 == 1) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else { 
+ assert(0); + } + +#else + + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 16816 - Float TF32 +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = tf32 * tf32 + F32 +template <> +struct SparseMma, 32, tfloat32_t, layout::RowMajor, + tfloat32_t, layout::ColumnMajor, float, layout::RowMajor, + OpMultiplyAdd, SPFormatType::Thread> { + using Shape = gemm::GemmShape<16, 8, 16>; + + using ElementA = tfloat32_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = tfloat32_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 4; + + static int const kMaxID2 = 2; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c, uint32_t const &E, int const id2) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else if (id2 == 1) { + asm volatile( + "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else { + assert(0); + } + +#else + + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 16864 - S8 input, S32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = 
reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + int8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + uint8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, 
{%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + uint8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 16864 - S8 input, S32 accumulation - SATURATE +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + int8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, 
{%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + int8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = int8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + uint8_t, + layout::RowMajor, + int8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = int8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + 
assert(0); + +#else + + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + uint8_t, + layout::RowMajor, + uint8_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = uint8_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = uint8_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 168128 - S4 input, S32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S4 * S4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + cutlass::int4b_t, + layout::RowMajor, + cutlass::int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), 
"r"(E)); + else + assert(0); + +#else + + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S4 * U4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + cutlass::int4b_t, + layout::RowMajor, + cutlass::uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * S4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + cutlass::uint4b_t, + layout::RowMajor, + cutlass::int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * U4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + 
cutlass::uint4b_t, + layout::RowMajor, + cutlass::uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAdd, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + assert(0); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// +// Sparse Matrix Multiply 168128 - S4 input, S32 accumulation - SATURATE +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: S32 = S4 * S4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + cutlass::int4b_t, + layout::RowMajor, + cutlass::int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = S4 * U4 + S32 
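//
// Illustrative usage sketch (not part of this patch): how a warp-level component
// might invoke one of the SparseMma atoms defined in this file. The free function,
// its name, and the way the fragments and metadata are staged are hypothetical;
// the fragment types, the packed metadata word E, and the id2 selector follow the
// F32 = F16 * F16 + F32 specialization earlier in this file.
//
using SparseMma16832 = cutlass::arch::SparseMma<
    cutlass::gemm::GemmShape<16, 8, 32>, 32,
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    float, cutlass::layout::RowMajor,
    cutlass::arch::OpMultiplyAdd,
    cutlass::arch::SPFormatType::Thread>;

__device__ void sparse_mma_16832_example(
    SparseMma16832::FragmentA const &a,   // compressed 2:4-sparse A tile (kSparse = 2)
    SparseMma16832::FragmentB const &b,   // dense B tile
    SparseMma16832::FragmentC       &d,   // accumulator fragment, updated in place
    uint32_t                         E,   // packed metadata word (kMetaSizeInBits bits per element)
    int                              id2) // metadata selector, 0 <= id2 < SparseMma16832::kMaxID2
{
  SparseMma16832 mma;
  mma(d, a, b, d, E, id2);                // D = A * B + C, with C aliased to D as in the warp-level MMAs
}
//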
+template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + cutlass::int4b_t, + layout::RowMajor, + cutlass::uint4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::int4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * S4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + cutlass::uint4b_t, + layout::RowMajor, + cutlass::int4b_t, + layout::ColumnMajor, + int, + layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::int4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + assert(0); +#endif + } +}; + +/// Matrix multiply-add operation: S32 = U4 * U4 + S32 +template <> +struct SparseMma< + gemm::GemmShape<16,8,128>, + 32, + cutlass::uint4b_t, + layout::RowMajor, + cutlass::uint4b_t, + layout::ColumnMajor, + int, + 
layout::RowMajor, + OpMultiplyAddSaturate, + SPFormatType::Thread> { + + using Shape = gemm::GemmShape<16,8,128>; + + using ElementA = cutlass::uint4b_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::uint4b_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = int; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = OpMultiplyAdd; + using ArchTag = arch::Sm80; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); + + if (id2 == 0) + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + else + assert(0); + +#else + + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/array.h b/include/cutlass/array.h index 0018b76f5a..3faa11d022 100644 --- a/include/cutlass/array.h +++ b/include/cutlass/array.h @@ -487,6 +487,46 @@ class Array { //////////////////////////////////////////////////////////////////////////////////////////////////// +template +CUTLASS_HOST_DEVICE +Array make_Array(Element x) { + Array m; + m[0] = x; + return m; +} + +template +CUTLASS_HOST_DEVICE +Array make_Array(Element x, Element y) { + Array m; + m[0] = x; + m[1] = y; + return m; +} + +template +CUTLASS_HOST_DEVICE +Array make_Array(Element x, Element y, Element z) { + Array m; + m[0] = x; + m[1] = y; + m[2] = z; + return m; +} + +template +CUTLASS_HOST_DEVICE +Array make_Array(Element x, Element y, Element z, Element w) { + Array m; + m[0] = x; + m[1] = y; + m[2] = z; + m[3] = w; + return m; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/bfloat16.h b/include/cutlass/bfloat16.h index c3bd1782bb..3a4b8bd76e 100644 --- a/include/cutlass/bfloat16.h +++ b/include/cutlass/bfloat16.h @@ -65,7 +65,7 @@ struct alignas(2) bfloat16_t { /// Default constructor CUTLASS_HOST_DEVICE - bfloat16_t() { } + bfloat16_t() : storage(0) { } /// Floating-point conversion - round toward nearest CUTLASS_HOST_DEVICE diff --git a/include/cutlass/complex.h b/include/cutlass/complex.h index 6f7d73bb91..7c0ab3b4f3 100644 --- a/include/cutlass/complex.h +++ b/include/cutlass/complex.h @@ -187,10 +187,12 @@ class complex /// Division template CUTLASS_HOST_DEVICE complex operator/(complex const &rhs) const { - T d = (rhs.real() * (rhs) + rhs.imag() * rhs.imag()); + T d = T(rhs.real() 
* rhs.real() + rhs.imag() * rhs.imag()); - return complex((this->real() * (rhs) + this->imag() * rhs.imag()) / d, - (this->imag() * (rhs)-this->real() * rhs.imag()) / d); + return complex( + (real() * rhs.real() + imag() * rhs.imag()) / d, + (imag() * rhs.real() - real() * rhs.imag()) / d + ); } /// Scalar Division diff --git a/include/cutlass/constants.h b/include/cutlass/constants.h new file mode 100644 index 0000000000..690891b227 --- /dev/null +++ b/include/cutlass/constants.h @@ -0,0 +1,1233 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* \file + \brief Boost-style constant definitions for floating-point types. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/complex.h" + +/////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace constants { + +/////////////////////////////////////////////////////////////////////////////////// + +// +// Primary templates +// + +/// Returns 1, the multiplicative identity element +template CUTLASS_HOST_DEVICE T one(); + +/// Returns 0, the additive identity element +template CUTLASS_HOST_DEVICE T zero(); + +/// Returns 2 +template CUTLASS_HOST_DEVICE T two(); + +/// Returns pi, approximately 3.141 +template CUTLASS_HOST_DEVICE T pi(); + +/// Returns 2 * pi +template CUTLASS_HOST_DEVICE T two_pi(); + +/// Returns pi / 2 +template CUTLASS_HOST_DEVICE T half_pi(); + +/// Returns sqrt(pi) +template CUTLASS_HOST_DEVICE T root_pi(); + +/// Returns sqrt(pi / 2) +template CUTLASS_HOST_DEVICE T root_half_pi(); + +/// Returns sqrt(2 * pi) +template CUTLASS_HOST_DEVICE T root_two_pi(); + +/// Returns sqrt(ln(4)) +template CUTLASS_HOST_DEVICE T root_ln_four(); + +/// Returns e, approximately 2.718... 
+template CUTLASS_HOST_DEVICE T e(); + +/// Returns (1/2) +template CUTLASS_HOST_DEVICE T half(); + +/// Returns sqrt(2), approximately 1.414... +template CUTLASS_HOST_DEVICE T root_two(); + +/// Returns sqrt(2)/2, approximately 0.707... +template CUTLASS_HOST_DEVICE T half_root_two(); + +/// Returns ln(2), approximately 0.693... +template CUTLASS_HOST_DEVICE T ln_two(); + +/// Returns ln(ln(2)), approximately -0.3665... +template CUTLASS_HOST_DEVICE T ln_ln_two(); + +/// Returns 1/3, approximately 0.333... +template CUTLASS_HOST_DEVICE T third(); + +/// Returns 2/3, approximately 0.666... +template CUTLASS_HOST_DEVICE T twothirds(); + +/// Returns pi - 3, approximately 0.1416... +template CUTLASS_HOST_DEVICE T pi_minus_three(); + +/// Returns 4 - pi, approximately 0.858... +template CUTLASS_HOST_DEVICE T four_minus_pi(); + + +///////////////////////////////////////////////////////////////////////////////////// + +// Specialization for double + +/// Returns 1, the multiplicative identity element (specialization for double) +template <> CUTLASS_HOST_DEVICE double one() { + uint64_t bits = 0x3ff0000000000000ull; + return reinterpret_cast(bits); +} + +/// Returns 1, the multiplicative identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex one< complex >() { + return complex(one(), double()); +} + +/// Returns 0, the additive identity element (specialization for double) +template <> CUTLASS_HOST_DEVICE double zero() { + uint64_t bits = 0x0ull; + return reinterpret_cast(bits); +} + +/// Returns 0, the additive identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex zero< complex >() { + return complex(zero(), double()); +} + +/// Returns 2 (specialization for double) +template <> CUTLASS_HOST_DEVICE double two() { + uint64_t bits = 0x4000000000000000ull; + return reinterpret_cast(bits); +} + +/// Returns 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two< complex >() { + return complex(two(), double()); +} + +/// Returns pi, approximately 3.141 (specialization for double) +template <> CUTLASS_HOST_DEVICE double pi() { + uint64_t bits = 0x400921fb54442d18ull; + return reinterpret_cast(bits); +} + +/// Returns pi, approximately 3.141 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi< complex >() { + return complex(pi(), double()); +} + +/// Returns 2 * pi (specialization for double) +template <> CUTLASS_HOST_DEVICE double two_pi() { + uint64_t bits = 0x401921fb54442d18ull; + return reinterpret_cast(bits); +} + +/// Returns 2 * pi (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { + return complex(two_pi(), double()); +} + +/// Returns pi / 2 (specialization for double) +template <> CUTLASS_HOST_DEVICE double half_pi() { + uint64_t bits = 0x3ff921fb54442d18ull; + return reinterpret_cast(bits); +} + +/// Returns pi / 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { + return complex(half_pi(), double()); +} + +/// Returns sqrt(pi) (specialization for double) +template <> CUTLASS_HOST_DEVICE double root_pi() { + uint64_t bits = 0x3ffc5bf891b4ef6aull; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { + return complex(root_pi(), double()); +} + +/// Returns sqrt(pi / 2) (specialization for double) +template <> CUTLASS_HOST_DEVICE double root_half_pi() { + uint64_t bits = 0x3ff40d931ff62705ull; + return 
reinterpret_cast(bits); +} + +/// Returns sqrt(pi / 2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { + return complex(root_half_pi(), double()); +} + +/// Returns sqrt(2 * pi) (specialization for double) +template <> CUTLASS_HOST_DEVICE double root_two_pi() { + uint64_t bits = 0x40040d931ff62705ull; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2 * pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { + return complex(root_two_pi(), double()); +} + +/// Returns sqrt(ln(4)) (specialization for double) +template <> CUTLASS_HOST_DEVICE double root_ln_four() { + uint64_t bits = 0x3ff2d6abe44afc43ull; + return reinterpret_cast(bits); +} + +/// Returns sqrt(ln(4)) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { + return complex(root_ln_four(), double()); +} + +/// Returns e, approximately 2.718... (specialization for double) +template <> CUTLASS_HOST_DEVICE double e() { + uint64_t bits = 0x4005bf0a8b145769ull; + return reinterpret_cast(bits); +} + +/// Returns e, approximately 2.718... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex e< complex >() { + return complex(e(), double()); +} + +/// Returns (1/2) (specialization for double) +template <> CUTLASS_HOST_DEVICE double half() { + uint64_t bits = 0x3fe0000000000000ull; + return reinterpret_cast(bits); +} + +/// Returns (1/2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half< complex >() { + return complex(half(), double()); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for double) +template <> CUTLASS_HOST_DEVICE double root_two() { + uint64_t bits = 0x3ff6a09e667f3bcdull; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { + return complex(root_two(), double()); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for double) +template <> CUTLASS_HOST_DEVICE double half_root_two() { + uint64_t bits = 0x3fe6a09e667f3bcdull; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { + return complex(half_root_two(), double()); +} + +/// Returns ln(2), approximately 0.693... (specialization for double) +template <> CUTLASS_HOST_DEVICE double ln_two() { + uint64_t bits = 0x3fe62e42fefa39efull; + return reinterpret_cast(bits); +} + +/// Returns ln(2), approximately 0.693... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { + return complex(ln_two(), double()); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for double) +template <> CUTLASS_HOST_DEVICE double ln_ln_two() { + uint64_t bits = 0xbfd774f29bdd6b9full; + return reinterpret_cast(bits); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { + return complex(ln_ln_two(), double()); +} + +/// Returns 1/3, approximately 0.333... (specialization for double) +template <> CUTLASS_HOST_DEVICE double third() { + uint64_t bits = 0x3fd5555555555555ull; + return reinterpret_cast(bits); +} + +/// Returns 1/3, approximately 0.333... 
(specialization for complex) +template <> CUTLASS_HOST_DEVICE complex third< complex >() { + return complex(third(), double()); +} + +/// Returns 2/3, approximately 0.666... (specialization for double) +template <> CUTLASS_HOST_DEVICE double twothirds() { + uint64_t bits = 0x3fe5555555555555ull; + return reinterpret_cast(bits); +} + +/// Returns 2/3, approximately 0.666... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { + return complex(twothirds(), double()); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for double) +template <> CUTLASS_HOST_DEVICE double pi_minus_three() { + uint64_t bits = 0x3fc21fb54442d180ull; + return reinterpret_cast(bits); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { + return complex(pi_minus_three(), double()); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for double) +template <> CUTLASS_HOST_DEVICE double four_minus_pi() { + uint64_t bits = 0x3feb7812aeef4ba0ull; + return reinterpret_cast(bits); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { + return complex(four_minus_pi(), double()); +} + +///////////////////////////////////////////////////////////////////////////////////// + +// Specialization for float + +/// Returns 1, the multiplicative identity element (specialization for float) +template <> CUTLASS_HOST_DEVICE float one() { + uint32_t bits = 0x3f800000u; + return reinterpret_cast(bits); +} + +/// Returns 1, the multiplicative identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex one< complex >() { + return complex(one(), float()); +} + +/// Returns 0, the additive identity element (specialization for float) +template <> CUTLASS_HOST_DEVICE float zero() { + uint32_t bits = 0x0u; + return reinterpret_cast(bits); +} + +/// Returns 0, the additive identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex zero< complex >() { + return complex(zero(), float()); +} + +/// Returns 2 (specialization for float) +template <> CUTLASS_HOST_DEVICE float two() { + uint32_t bits = 0x40000000u; + return reinterpret_cast(bits); +} + +/// Returns 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two< complex >() { + return complex(two(), float()); +} + +/// Returns pi, approximately 3.141 (specialization for float) +template <> CUTLASS_HOST_DEVICE float pi() { + uint32_t bits = 0x40490fdbu; + return reinterpret_cast(bits); +} + +/// Returns pi, approximately 3.141 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi< complex >() { + return complex(pi(), float()); +} + +/// Returns 2 * pi (specialization for float) +template <> CUTLASS_HOST_DEVICE float two_pi() { + uint32_t bits = 0x40c90fdbu; + return reinterpret_cast(bits); +} + +/// Returns 2 * pi (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { + return complex(two_pi(), float()); +} + +/// Returns pi / 2 (specialization for float) +template <> CUTLASS_HOST_DEVICE float half_pi() { + uint32_t bits = 0x3fc90fdbu; + return reinterpret_cast(bits); +} + +/// Returns pi / 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { + return complex(half_pi(), float()); +} + +/// Returns sqrt(pi) (specialization for float) +template <> CUTLASS_HOST_DEVICE float 
root_pi() { + uint32_t bits = 0x3fe2dfc5u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { + return complex(root_pi(), float()); +} + +/// Returns sqrt(pi / 2) (specialization for float) +template <> CUTLASS_HOST_DEVICE float root_half_pi() { + uint32_t bits = 0x3fa06c99u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi / 2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { + return complex(root_half_pi(), float()); +} + +/// Returns sqrt(2 * pi) (specialization for float) +template <> CUTLASS_HOST_DEVICE float root_two_pi() { + uint32_t bits = 0x40206c99u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2 * pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { + return complex(root_two_pi(), float()); +} + +/// Returns sqrt(ln(4)) (specialization for float) +template <> CUTLASS_HOST_DEVICE float root_ln_four() { + uint32_t bits = 0x3f96b55fu; + return reinterpret_cast(bits); +} + +/// Returns sqrt(ln(4)) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { + return complex(root_ln_four(), float()); +} + +/// Returns e, approximately 2.718... (specialization for float) +template <> CUTLASS_HOST_DEVICE float e() { + uint32_t bits = 0x402df854u; + return reinterpret_cast(bits); +} + +/// Returns e, approximately 2.718... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex e< complex >() { + return complex(e(), float()); +} + +/// Returns (1/2) (specialization for float) +template <> CUTLASS_HOST_DEVICE float half() { + uint32_t bits = 0x3f000000u; + return reinterpret_cast(bits); +} + +/// Returns (1/2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half< complex >() { + return complex(half(), float()); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for float) +template <> CUTLASS_HOST_DEVICE float root_two() { + uint32_t bits = 0x3fb504f3u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { + return complex(root_two(), float()); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for float) +template <> CUTLASS_HOST_DEVICE float half_root_two() { + uint32_t bits = 0x3f3504f3u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { + return complex(half_root_two(), float()); +} + +/// Returns ln(2), approximately 0.693... (specialization for float) +template <> CUTLASS_HOST_DEVICE float ln_two() { + uint32_t bits = 0x3f317218u; + return reinterpret_cast(bits); +} + +/// Returns ln(2), approximately 0.693... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { + return complex(ln_two(), float()); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for float) +template <> CUTLASS_HOST_DEVICE float ln_ln_two() { + uint32_t bits = 0xbebba795u; + return reinterpret_cast(bits); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { + return complex(ln_ln_two(), float()); +} + +/// Returns 1/3, approximately 0.333... 
(specialization for float) +template <> CUTLASS_HOST_DEVICE float third() { + uint32_t bits = 0x3eaaaaabu; + return reinterpret_cast(bits); +} + +/// Returns 1/3, approximately 0.333... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex third< complex >() { + return complex(third(), float()); +} + +/// Returns 2/3, approximately 0.666... (specialization for float) +template <> CUTLASS_HOST_DEVICE float twothirds() { + uint32_t bits = 0x3f2aaaabu; + return reinterpret_cast(bits); +} + +/// Returns 2/3, approximately 0.666... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { + return complex(twothirds(), float()); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for float) +template <> CUTLASS_HOST_DEVICE float pi_minus_three() { + uint32_t bits = 0x3e10fdaau; + return reinterpret_cast(bits); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { + return complex(pi_minus_three(), float()); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for float) +template <> CUTLASS_HOST_DEVICE float four_minus_pi() { + uint32_t bits = 0x3f5bc095u; + return reinterpret_cast(bits); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { + return complex(four_minus_pi(), float()); +} + +///////////////////////////////////////////////////////////////////////////////////// + +// Specialization for tfloat32_t + +/// Returns 1, the multiplicative identity element (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t one() { + uint32_t bits = 0x3f801000u; + return reinterpret_cast(bits); +} + +/// Returns 1, the multiplicative identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex one< complex >() { + return complex(one(), tfloat32_t()); +} + +/// Returns 0, the additive identity element (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t zero() { + uint32_t bits = 0x1000u; + return reinterpret_cast(bits); +} + +/// Returns 0, the additive identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex zero< complex >() { + return complex(zero(), tfloat32_t()); +} + +/// Returns 2 (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t two() { + uint32_t bits = 0x40001000u; + return reinterpret_cast(bits); +} + +/// Returns 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two< complex >() { + return complex(two(), tfloat32_t()); +} + +/// Returns pi, approximately 3.141 (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t pi() { + uint32_t bits = 0x40491fdbu; + return reinterpret_cast(bits); +} + +/// Returns pi, approximately 3.141 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi< complex >() { + return complex(pi(), tfloat32_t()); +} + +/// Returns 2 * pi (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t two_pi() { + uint32_t bits = 0x40c91fdbu; + return reinterpret_cast(bits); +} + +/// Returns 2 * pi (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { + return complex(two_pi(), tfloat32_t()); +} + +/// Returns pi / 2 (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t half_pi() { + uint32_t bits = 0x3fc91fdbu; + return reinterpret_cast(bits); +} + +/// 
Returns pi / 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { + return complex(half_pi(), tfloat32_t()); +} + +/// Returns sqrt(pi) (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t root_pi() { + uint32_t bits = 0x3fe2efc5u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { + return complex(root_pi(), tfloat32_t()); +} + +/// Returns sqrt(pi / 2) (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t root_half_pi() { + uint32_t bits = 0x3fa07c99u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi / 2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { + return complex(root_half_pi(), tfloat32_t()); +} + +/// Returns sqrt(2 * pi) (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t root_two_pi() { + uint32_t bits = 0x40207c99u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2 * pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { + return complex(root_two_pi(), tfloat32_t()); +} + +/// Returns sqrt(ln(4)) (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t root_ln_four() { + uint32_t bits = 0x3f96c55fu; + return reinterpret_cast(bits); +} + +/// Returns sqrt(ln(4)) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { + return complex(root_ln_four(), tfloat32_t()); +} + +/// Returns e, approximately 2.718... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t e() { + uint32_t bits = 0x402e0854u; + return reinterpret_cast(bits); +} + +/// Returns e, approximately 2.718... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex e< complex >() { + return complex(e(), tfloat32_t()); +} + +/// Returns (1/2) (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t half() { + uint32_t bits = 0x3f001000u; + return reinterpret_cast(bits); +} + +/// Returns (1/2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half< complex >() { + return complex(half(), tfloat32_t()); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t root_two() { + uint32_t bits = 0x3fb514f3u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { + return complex(root_two(), tfloat32_t()); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t half_root_two() { + uint32_t bits = 0x3f3514f3u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { + return complex(half_root_two(), tfloat32_t()); +} + +/// Returns ln(2), approximately 0.693... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t ln_two() { + uint32_t bits = 0x3f318218u; + return reinterpret_cast(bits); +} + +/// Returns ln(2), approximately 0.693... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { + return complex(ln_two(), tfloat32_t()); +} + +/// Returns ln(ln(2)), approximately -0.3665... 
(specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t ln_ln_two() { + uint32_t bits = 0xbebbb795u; + return reinterpret_cast(bits); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { + return complex(ln_ln_two(), tfloat32_t()); +} + +/// Returns 1/3, approximately 0.333... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t third() { + uint32_t bits = 0x3eaabaabu; + return reinterpret_cast(bits); +} + +/// Returns 1/3, approximately 0.333... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex third< complex >() { + return complex(third(), tfloat32_t()); +} + +/// Returns 2/3, approximately 0.666... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t twothirds() { + uint32_t bits = 0x3f2abaabu; + return reinterpret_cast(bits); +} + +/// Returns 2/3, approximately 0.666... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { + return complex(twothirds(), tfloat32_t()); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t pi_minus_three() { + uint32_t bits = 0x3e110daau; + return reinterpret_cast(bits); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { + return complex(pi_minus_three(), tfloat32_t()); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for tfloat32_t) +template <> CUTLASS_HOST_DEVICE tfloat32_t four_minus_pi() { + uint32_t bits = 0x3f5bd095u; + return reinterpret_cast(bits); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { + return complex(four_minus_pi(), tfloat32_t()); +} + +///////////////////////////////////////////////////////////////////////////////////// + +// Specialization for half_t + +/// Returns 1, the multiplicative identity element (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t one() { + uint16_t bits = 0x3c00u; + return reinterpret_cast(bits); +} + +/// Returns 1, the multiplicative identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex one< complex >() { + return complex(one(), half_t()); +} + +/// Returns 0, the additive identity element (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t zero() { + uint16_t bits = 0x0u; + return reinterpret_cast(bits); +} + +/// Returns 0, the additive identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex zero< complex >() { + return complex(zero(), half_t()); +} + +/// Returns 2 (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t two() { + uint16_t bits = 0x4000u; + return reinterpret_cast(bits); +} + +/// Returns 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two< complex >() { + return complex(two(), half_t()); +} + +/// Returns pi, approximately 3.141 (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t pi() { + uint16_t bits = 0x4248u; + return reinterpret_cast(bits); +} + +/// Returns pi, approximately 3.141 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi< complex >() { + return complex(pi(), half_t()); +} + +/// Returns 2 * pi (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t two_pi() { + uint16_t bits = 0x4648u; + 
return reinterpret_cast(bits); +} + +/// Returns 2 * pi (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { + return complex(two_pi(), half_t()); +} + +/// Returns pi / 2 (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t half_pi() { + uint16_t bits = 0x3e48u; + return reinterpret_cast(bits); +} + +/// Returns pi / 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { + return complex(half_pi(), half_t()); +} + +/// Returns sqrt(pi) (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t root_pi() { + uint16_t bits = 0x3f17u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { + return complex(root_pi(), half_t()); +} + +/// Returns sqrt(pi / 2) (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t root_half_pi() { + uint16_t bits = 0x3d03u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi / 2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { + return complex(root_half_pi(), half_t()); +} + +/// Returns sqrt(2 * pi) (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t root_two_pi() { + uint16_t bits = 0x4103u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2 * pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { + return complex(root_two_pi(), half_t()); +} + +/// Returns sqrt(ln(4)) (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t root_ln_four() { + uint16_t bits = 0x3cb6u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(ln(4)) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { + return complex(root_ln_four(), half_t()); +} + +/// Returns e, approximately 2.718... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t e() { + uint16_t bits = 0x4170u; + return reinterpret_cast(bits); +} + +/// Returns e, approximately 2.718... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex e< complex >() { + return complex(e(), half_t()); +} + +/// Returns (1/2) (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t half() { + uint16_t bits = 0x3800u; + return reinterpret_cast(bits); +} + +/// Returns (1/2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half< complex >() { + return complex(half(), half_t()); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t root_two() { + uint16_t bits = 0x3da8u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { + return complex(root_two(), half_t()); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t half_root_two() { + uint16_t bits = 0x39a8u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { + return complex(half_root_two(), half_t()); +} + +/// Returns ln(2), approximately 0.693... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t ln_two() { + uint16_t bits = 0x398cu; + return reinterpret_cast(bits); +} + +/// Returns ln(2), approximately 0.693... 
(specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { + return complex(ln_two(), half_t()); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t ln_ln_two() { + uint16_t bits = 0xb5ddu; + return reinterpret_cast(bits); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { + return complex(ln_ln_two(), half_t()); +} + +/// Returns 1/3, approximately 0.333... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t third() { + uint16_t bits = 0x3555u; + return reinterpret_cast(bits); +} + +/// Returns 1/3, approximately 0.333... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex third< complex >() { + return complex(third(), half_t()); +} + +/// Returns 2/3, approximately 0.666... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t twothirds() { + uint16_t bits = 0x3955u; + return reinterpret_cast(bits); +} + +/// Returns 2/3, approximately 0.666... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { + return complex(twothirds(), half_t()); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t pi_minus_three() { + uint16_t bits = 0x3088u; + return reinterpret_cast(bits); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { + return complex(pi_minus_three(), half_t()); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for half_t) +template <> CUTLASS_HOST_DEVICE half_t four_minus_pi() { + uint16_t bits = 0x3adeu; + return reinterpret_cast(bits); +} + +/// Returns 4 - pi, approximately 0.858... 
(specialization for complex) +template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { + return complex(four_minus_pi(), half_t()); +} + +///////////////////////////////////////////////////////////////////////////////////// + +// Specialization for bfloat16_t + +/// Returns 1, the multiplicative identity element (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t one() { + uint16_t bits = 0x3f80u; + return reinterpret_cast(bits); +} + +/// Returns 1, the multiplicative identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex one< complex >() { + return complex(one(), bfloat16_t()); +} + +/// Returns 0, the additive identity element (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t zero() { + uint16_t bits = 0x0u; + return reinterpret_cast(bits); +} + +/// Returns 0, the additive identity element (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex zero< complex >() { + return complex(zero(), bfloat16_t()); +} + +/// Returns 2 (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t two() { + uint16_t bits = 0x4000u; + return reinterpret_cast(bits); +} + +/// Returns 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two< complex >() { + return complex(two(), bfloat16_t()); +} + +/// Returns pi, approximately 3.141 (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t pi() { + uint16_t bits = 0x4049u; + return reinterpret_cast(bits); +} + +/// Returns pi, approximately 3.141 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi< complex >() { + return complex(pi(), bfloat16_t()); +} + +/// Returns 2 * pi (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t two_pi() { + uint16_t bits = 0x40c9u; + return reinterpret_cast(bits); +} + +/// Returns 2 * pi (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex two_pi< complex >() { + return complex(two_pi(), bfloat16_t()); +} + +/// Returns pi / 2 (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t half_pi() { + uint16_t bits = 0x3fc9u; + return reinterpret_cast(bits); +} + +/// Returns pi / 2 (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_pi< complex >() { + return complex(half_pi(), bfloat16_t()); +} + +/// Returns sqrt(pi) (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t root_pi() { + uint16_t bits = 0x3fe3u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_pi< complex >() { + return complex(root_pi(), bfloat16_t()); +} + +/// Returns sqrt(pi / 2) (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t root_half_pi() { + uint16_t bits = 0x3fa0u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(pi / 2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_half_pi< complex >() { + return complex(root_half_pi(), bfloat16_t()); +} + +/// Returns sqrt(2 * pi) (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t root_two_pi() { + uint16_t bits = 0x4020u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2 * pi) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two_pi< complex >() { + return complex(root_two_pi(), bfloat16_t()); +} + +/// Returns sqrt(ln(4)) (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t root_ln_four() { + uint16_t bits 
= 0x3f97u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(ln(4)) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_ln_four< complex >() { + return complex(root_ln_four(), bfloat16_t()); +} + +/// Returns e, approximately 2.718... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t e() { + uint16_t bits = 0x402eu; + return reinterpret_cast(bits); +} + +/// Returns e, approximately 2.718... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex e< complex >() { + return complex(e(), bfloat16_t()); +} + +/// Returns (1/2) (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t half() { + uint16_t bits = 0x3f00u; + return reinterpret_cast(bits); +} + +/// Returns (1/2) (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half< complex >() { + return complex(half(), bfloat16_t()); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t root_two() { + uint16_t bits = 0x3fb5u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2), approximately 1.414... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex root_two< complex >() { + return complex(root_two(), bfloat16_t()); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t half_root_two() { + uint16_t bits = 0x3f35u; + return reinterpret_cast(bits); +} + +/// Returns sqrt(2)/2, approximately 0.707... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex half_root_two< complex >() { + return complex(half_root_two(), bfloat16_t()); +} + +/// Returns ln(2), approximately 0.693... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t ln_two() { + uint16_t bits = 0x3f31u; + return reinterpret_cast(bits); +} + +/// Returns ln(2), approximately 0.693... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_two< complex >() { + return complex(ln_two(), bfloat16_t()); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t ln_ln_two() { + uint16_t bits = 0xbebcu; + return reinterpret_cast(bits); +} + +/// Returns ln(ln(2)), approximately -0.3665... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex ln_ln_two< complex >() { + return complex(ln_ln_two(), bfloat16_t()); +} + +/// Returns 1/3, approximately 0.333... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t third() { + uint16_t bits = 0x3eabu; + return reinterpret_cast(bits); +} + +/// Returns 1/3, approximately 0.333... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex third< complex >() { + return complex(third(), bfloat16_t()); +} + +/// Returns 2/3, approximately 0.666... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t twothirds() { + uint16_t bits = 0x3f2bu; + return reinterpret_cast(bits); +} + +/// Returns 2/3, approximately 0.666... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex twothirds< complex >() { + return complex(twothirds(), bfloat16_t()); +} + +/// Returns pi - 3, approximately 0.1416... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t pi_minus_three() { + uint16_t bits = 0x3e11u; + return reinterpret_cast(bits); +} + +/// Returns pi - 3, approximately 0.1416... 
(specialization for complex) +template <> CUTLASS_HOST_DEVICE complex pi_minus_three< complex >() { + return complex(pi_minus_three(), bfloat16_t()); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for bfloat16_t) +template <> CUTLASS_HOST_DEVICE bfloat16_t four_minus_pi() { + uint16_t bits = 0x3f5cu; + return reinterpret_cast(bits); +} + +/// Returns 4 - pi, approximately 0.858... (specialization for complex) +template <> CUTLASS_HOST_DEVICE complex four_minus_pi< complex >() { + return complex(four_minus_pi(), bfloat16_t()); +} +/////////////////////////////////////////////////////////////////////////////////// + +} // namespace constants +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/coord.h b/include/cutlass/coord.h index 82613c2450..181e3116e8 100644 --- a/include/cutlass/coord.h +++ b/include/cutlass/coord.h @@ -439,6 +439,12 @@ Coord<4> make_Coord(int _0, int _1, int _2, int _3) { return Coord<4>(values); } +/// Helper to make a 5-element coordinate +CUTLASS_HOST_DEVICE +Coord<5> make_Coord(int _0, int _1, int _2, int _3, int _4) { + int values[5] = {_0, _1, _2, _3, _4}; + return Coord<5>(values); +} //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/include/cutlass/core_io.h b/include/cutlass/core_io.h index a87ecfa707..1f624f1fa8 100644 --- a/include/cutlass/core_io.h +++ b/include/cutlass/core_io.h @@ -31,18 +31,43 @@ #include #include +#include "cutlass/array.h" #include "cutlass/coord.h" #include "cutlass/numeric_types.h" #include "cutlass/matrix_shape.h" #include "cutlass/layout/pitch_linear.h" +#include "cutlass/tensor_view.h" #include "cutlass/gemm/gemm.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Output operator for CUDA built-in dim3 type +inline std::ostream &operator<<(std::ostream &out, dim3 d) { + return out << d.x << ", " << d.y << ", " << d.z; +} + +/// Output operator for CUDA built-in error type +inline std::ostream &operator<<(std::ostream &out, cudaError_t error) { + return out << cudaGetErrorString(error); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + namespace cutlass { /////////////////////////////////////////////////////////////////////////////////////////////////// // stream operators for cutlass namespace // /////////////////////////////////////////////////////////////////////////////////////////////////// +template +inline +std::ostream& operator<<(std::ostream& out, Array const& v) { + for (int i = 0; i < Rank; ++i) { + out << (i ? 
", " : "") << v[i]; + } + return out; +} + template inline std::ostream& operator<<(std::ostream& out, Coord const& coord) { @@ -115,7 +140,7 @@ inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scal /// Default printing to ostream for MatrixShape template inline -std::ostream & operator<<(std::ostream &out, cutlass::MatrixShape const &matrix_shape) { +std::ostream & operator<<(std::ostream &out, MatrixShape const &matrix_shape) { out << "cutlass::MatrixShape::(kRow, kColumn) {" << cutlass::MatrixShape::kRow <<"," << cutlass::MatrixShape::kColumn <<"}"; @@ -130,7 +155,7 @@ namespace gemm { /// Default printing to ostream for GemmShape template inline -std::ostream & operator<<(std::ostream &out, cutlass::gemm::GemmShape const &gemm_shape) { +std::ostream & operator<<(std::ostream &out, GemmShape const &gemm_shape) { out << "cutlass::GemmShape::(kM, kN, kK) {" << cutlass::gemm::GemmShape::kM <<"," << cutlass::gemm::GemmShape::kN <<"," @@ -150,7 +175,7 @@ namespace layout { /// Default printing to ostream for PitchLinearShape template < int Contiguous, int Strided> inline -std::ostream & operator<<(std::ostream &out, cutlass::layout::PitchLinearShape const &pitch_linear_shape) { +std::ostream & operator<<(std::ostream &out, PitchLinearShape const &pitch_linear_shape) { out << "cutlass::layout::PitchLinearShape::(kContiguous, kStrided) {" << cutlass::layout::PitchLinearShape::kContiguous <<"," << cutlass::layout::PitchLinearShape::kStrided <<"}"; diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index 860dc3e566..622f037b40 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -125,13 +125,6 @@ static char const* cutlassGetStatusString(cutlass::Status status) { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Debug { - typename T::X x; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - static const int NUM_THREADS_PER_WARP = 32; static const int NUM_THREADS_PER_HALF_WARP = NUM_THREADS_PER_WARP / 2; static const int NUM_THREADS_PER_QUAD = 4; @@ -143,7 +136,7 @@ static const int NUM_THREADS_PER_QUAD_PAIR = NUM_THREADS_PER_QUAD * 2; CUTLASS_DEVICE int LaneId() { int ret; - asm ("mov.u32 %0, %%laneid;" : "=r"(ret)); + asm ("mov.u32 %0, %%laneid;" : "=r"(ret) : ); return ret; } @@ -151,7 +144,7 @@ int LaneId() { CUTLASS_DEVICE int SmId() { int ret; - asm ("mov.u32 %0, %%smid;" : "=r"(ret)); + asm ("mov.u32 %0, %%smid;" : "=r"(ret) : ); return ret; } diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index c0f42146e6..d352ea5a64 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -31,9 +31,8 @@ #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" - +#include "cutlass/constants.h" #include "cutlass/complex.h" - #include "cutlass/array.h" #include "cutlass/half.h" #include "cutlass/functional.h" @@ -108,6 +107,40 @@ struct Sigmoid > { } }; +// GELU operator +template +struct GELU { + CUTLASS_HOST_DEVICE + T operator()(T const &scalar) const { + return T(cutlass::constants::half() * scalar * + (cutlass::constants::one() + erff( scalar / cutlass::constants::root_two() ))); + } +}; + +template <> +struct GELU { + CUTLASS_HOST_DEVICE + float operator()(float const &scalar) const { + return cutlass::constants::half() * scalar * + (cutlass::constants::one() + erff( scalar / cutlass::constants::root_two() 
)); + } +}; + +template +struct GELU > { + CUTLASS_HOST_DEVICE + Array operator()(Array const &rhs) const { + Array y; + GELU gelu_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < int(rhs.size()); ++i) { + y[i] = gelu_op(rhs[i]); + } + + return y; + } +}; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/linear_combination.h b/include/cutlass/epilogue/thread/linear_combination.h index 8b5f6ead1c..31c91643c1 100644 --- a/include/cutlass/epilogue/thread/linear_combination.h +++ b/include/cutlass/epilogue/thread/linear_combination.h @@ -95,6 +95,13 @@ class LinearCombination { } + CUTLASS_HOST_DEVICE + Params( + ElementCompute alpha + ): alpha(alpha), beta(0), alpha_ptr(nullptr), beta_ptr(nullptr) { + + } + CUTLASS_HOST_DEVICE Params( ElementCompute const *alpha_ptr, @@ -102,6 +109,13 @@ class LinearCombination { ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute const *alpha_ptr + ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(nullptr) { + + } }; private: diff --git a/include/cutlass/epilogue/thread/linear_combination_clamp.h b/include/cutlass/epilogue/thread/linear_combination_clamp.h index 25611bd36c..d1a0e6d908 100644 --- a/include/cutlass/epilogue/thread/linear_combination_clamp.h +++ b/include/cutlass/epilogue/thread/linear_combination_clamp.h @@ -236,6 +236,18 @@ class LinearCombinationClamp { using ElementAccumulator = int; using ElementCompute = float; + static_assert( + platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value, + "This elementwise op expects the output to be int."); + static int const kCount = Count; using FragmentOutput = Array; @@ -392,8 +404,9 @@ class LinearCombinationClamp { /// /// D = alpha * accumulator + beta * source + uniform /// -/// Note: The below method only works for small k dimensions. The default -/// approach is above +/// Note: The below method only when problem_size_K <= 256 for signed int8 gemm +/// or problem_size_K <= 128 for unsigned int8 gemm. The default approach is +/// above. /// TODO: Add logic to fallback to the default approach template < /// Data type used to load and store< tensors @@ -408,6 +421,18 @@ class FastLinearCombinationClamp { using ElementAccumulator = int; using ElementCompute = float; + static_assert( + platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value, + "This elementwise op expects the output to be int."); + static int const kCount = Count; using FragmentOutput = Array; diff --git a/include/cutlass/epilogue/thread/linear_combination_gelu.h b/include/cutlass/epilogue/thread/linear_combination_gelu.h new file mode 100644 index 0000000000..30b6213478 --- /dev/null +++ b/include/cutlass/epilogue/thread/linear_combination_gelu.h @@ -0,0 +1,206 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination with GELU operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/epilogue/thread/activation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. 
+/// +/// D = alpha * accumulator + beta * source + uniform +/// +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +class LinearCombinationGELU { +public: + + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + + static FloatRoundStyle const kRound = Round; + + /// Host-constructable parameters structure + struct Params { + + ElementCompute alpha; ///< scales accumulators + ElementCompute beta; ///< scales source tensor + ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory + ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + alpha(ElementCompute(1)), + beta(ElementCompute(0)), + alpha_ptr(nullptr), + beta_ptr(nullptr) { } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute alpha, + ElementCompute beta + ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { + + } + + CUTLASS_HOST_DEVICE + Params( + ElementCompute const *alpha_ptr, + ElementCompute const *beta_ptr + ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { + + } + }; + +private: + + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + +public: + + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LinearCombinationGELU(Params const ¶ms) { + + alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); + beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); + } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return beta_ != ElementCompute(0); + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition) { + if (k_partition) { + beta_ = ElementCompute(1); + } + } + + /// Computes: D = gelu( alpha * accumulator + beta * source ) + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator, + FragmentOutput const &source) const { + + // Convert source to internal compute numeric type + NumericArrayConverter<ElementCompute, ElementOutput, kCount, Round> source_converter; + NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter; + + ComputeFragment converted_source = source_converter(source); + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + + ComputeFragment intermediate; + + multiplies<ComputeFragment> mul_add_source; + multiply_add<ComputeFragment, ComputeFragment, ComputeFragment> mul_add_accumulator; + GELU<ComputeFragment> gelu; + + intermediate = mul_add_source(beta_, converted_source); // X = beta * C + uniform + intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X + + intermediate = gelu(intermediate); + + // Convert to destination numeric type + NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter; + + return destination_converter(intermediate); + } + + /// Computes: D = gelu( alpha * accumulator ) + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator) const { + + // Convert source to internal compute numeric type + NumericArrayConverter<ElementCompute, ElementAccumulator, kCount, Round> accumulator_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + + ComputeFragment intermediate; + + multiplies<ComputeFragment> mul_add_accumulator; + GELU<ComputeFragment> gelu; + + intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum + + intermediate = gelu(intermediate); + + // Convert to destination numeric type + NumericArrayConverter<ElementOutput, ElementCompute, kCount, Round> destination_converter; + + return destination_converter(intermediate); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h index 67fccf05c2..5c12f21680 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_complex_tensor_op.h @@ -78,7 +78,8 @@ template < int ElementsPerAccess, /// Multiply-add operator /// Selects between (arch::OpMultiplyAddComplex, arch::OpMultiplyGaussianComplex) - typename Operator_ = arch::OpMultiplyAddComplex> + typename Operator_ = arch::OpMultiplyAddComplex +> struct DefaultEpilogueComplexTensorOp { using Shape = Shape_; @@ -87,7 +88,6 @@ struct DefaultEpilogueComplexTensorOp { using OutputOp = OutputOp_; static int const kElementsPerAccess = ElementsPerAccess; using Operator = Operator_; - using ElementOutput = typename OutputOp::ElementOutput; using LayoutC = typename WarpMmaTensorOp::LayoutC; using ElementAccumulator = typename WarpMmaTensorOp::ElementC; @@ -164,7 +164,8 @@ template < > struct DefaultEpilogueComplexTensorOp { + arch::OpMultiplyAddGaussianComplex +> { using Shape = Shape_; using WarpMmaTensorOp = WarpMmaTensorOp_; @@ -172,7 +173,6 @@ struct DefaultEpilogueComplexTensorOp { /// Number of
operations using OperatorCount = MatrixShape< - WarpShape::kM / OperatorShape::kM, - WarpShape::kN / OperatorShape::kN + (WarpShape::kM + OperatorShape::kM - 1) / OperatorShape::kM, + (WarpShape::kN + OperatorShape::kN - 1) / OperatorShape::kN >; // @@ -70,6 +70,8 @@ struct TensorOpPolicy { static int const kElementsPerAccess = 2; static int const kRowsPerIteration = 8; + static bool const kDivisible = + !(WarpShape::kM % OperatorShape::kM) && !(WarpShape::kN % OperatorShape::kN); // // Derived quantities diff --git a/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h b/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h index 04c361f5ee..33cee0d375 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h +++ b/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h @@ -29,6 +29,7 @@ #pragma once #include "cutlass/array.h" +#include "cutlass/tensor_ref.h" #include "cutlass/layout/matrix.h" #include "cutlass/layout/pitch_linear.h" @@ -116,6 +117,9 @@ class TileIteratorTensorOp(ref.data())), - layout_(ref.stride()[0] / Policy::kElementsPerAccess) { + layout_(ref.stride()[0] / Policy::kElementsPerAccess) { int quad_id = (lane_id / Detail::kLanesInQuad); int lane_in_quad = (lane_id % Detail::kLanesInQuad); - pointer_ += layout_({quad_id, lane_in_quad}); + thread_offset_ = { + quad_id, lane_in_quad * Policy::kElementsPerAccess + }; + + pointer_ += layout_({thread_offset_.row(), thread_offset_.column() / Policy::kElementsPerAccess}); } /// Adds a pointer offset @@ -148,9 +156,16 @@ class TileIteratorTensorOp +class TileIteratorTensorOpCanonical { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using Element = Element_; + using Layout = Layout_; + + using TensorRef = TensorRef; ///< Tensor Reference object + using TensorCoord = MatrixCoord; ///< Logical coordinate in referenced tensor + using Index = typename TensorRef::Index; + using LongIndex = typename TensorRef::LongIndex; + + using Policy = TensorOpPolicy; + + static int const kAccessSize = 1; + static int const kAccessCount = Policy::kElementsPerAccess / kAccessSize; + + /// Shape of the tile in memory + using Shape = MatrixShape< + Policy::kRowsPerIteration, + WarpShape::kN + >; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + Element, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + /// This is the complete warp-level accumulator tile. 
+ //using AccumulatorTile = typename Operator::FragmentC; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + + // Internal constants + struct Detail { + static int const kLanesInQuad = 4; + }; + + /// Padding quantity + using Padding = MatrixShape< + 0, + Detail::kLanesInQuad * Policy::kElementsPerAccess>; + +private: + + /// Storage type for accessing memory + using AccessType = AlignedArray; + + // + // Data members + // + + /// Internal pointer to memory + AccessType *pointer_; + + /// Internal layout object + Layout layout_; + + /// Guard to indicate whether the shape is divisible + bool divisible_; + + /// Extent of the output tensor + MatrixCoord extent_; + + /// Thread offset + MatrixCoord thread_offset_; + +public: + + /// Default constructor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpCanonical(): pointer_(nullptr) { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorTensorOpCanonical( + TensorRef const &ref, + unsigned lane_id + ): + pointer_(reinterpret_cast(ref.data())), + layout_(ref.stride()[0]), + divisible_(true), + extent_(WarpShape::kM, WarpShape::kN) { + + int quad_id = (lane_id / Detail::kLanesInQuad); + int lane_in_quad = (lane_id % Detail::kLanesInQuad); + + thread_offset_ = { + quad_id, lane_in_quad * Policy::kElementsPerAccess + }; + + pointer_ += layout_({thread_offset_.row(), thread_offset_.column()}); + } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + TileIteratorTensorOpCanonical( + TensorRef const &ref, + TensorCoord const &extent, + unsigned lane_id + ): + pointer_(reinterpret_cast(ref.data())), + layout_(ref.stride()[0]), + divisible_(false), + extent_(extent) { + + int quad_id = (lane_id / Detail::kLanesInQuad); + int lane_in_quad = (lane_id % Detail::kLanesInQuad); + + thread_offset_ = { + quad_id, lane_in_quad * Policy::kElementsPerAccess + }; + + pointer_ += layout_({thread_offset_.row(), thread_offset_.column()}); + } + + /// Adds a pointer offset + CUTLASS_HOST_DEVICE + TileIteratorTensorOpCanonical & add_pointer_offset(Index pointer_offset) { + pointer_ += pointer_offset; + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpCanonical & add_tile_offset(TensorCoord const &tile_offset) { + + MatrixCoord coord_offset( + tile_offset.row() * Shape::kRow, + tile_offset.column() * Shape::kColumn + ); + + thread_offset_ += coord_offset; + + pointer_ += layout_({ + coord_offset.row(), + coord_offset.column() + }); + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_HOST_DEVICE + TileIteratorTensorOpCanonical & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(tile_offset); + return *this; + } + + /// Store + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) { + + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int a = 0; a < kAccessCount; ++a) { + + int ptr_idx = n * Detail::kLanesInQuad * kAccessCount + pointer_offset + a; + int frag_idx = n * kAccessCount + a; + + int col = thread_offset_.column() + n * Detail::kLanesInQuad * Policy::kElementsPerAccess + a; + + if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) { + pointer_[ptr_idx] = frag_ptr[frag_idx]; + } + } + } + } + + /// Store 
+ CUTLASS_HOST_DEVICE + void store(Fragment const &frag) { + store_with_pointer_offset(frag, 0); + } + + /// Load + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int a = 0; a < kAccessCount; ++a) { + + int ptr_idx = n * Detail::kLanesInQuad * kAccessCount + pointer_offset + a; + int frag_idx = n * kAccessCount + a; + + int col = thread_offset_.column() + n * Detail::kLanesInQuad * Policy::kElementsPerAccess + a; + + if (divisible_ || (thread_offset_.row() < extent_.row() && col < extent_.column())) { + frag_ptr[frag_idx] = pointer_[ptr_idx]; + } + } + } + } + + /// Load + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + CUTLASS_HOST_DEVICE + TileIteratorTensorOpCanonical & operator++() { + return add_tile_offset({1, 0}); + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h b/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h index 8ffb5ec128..1754f58016 100644 --- a/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h +++ b/include/cutlass/epilogue/warp/tile_iterator_volta_tensor_op.h @@ -33,6 +33,7 @@ #include "cutlass/layout/pitch_linear.h" #include "cutlass/epilogue/warp/tensor_op_policy.h" +#include "cutlass/epilogue/warp/volta_tensor_op_policy.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index 036b08e23b..978e614b1d 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -29,6 +29,7 @@ #include #else #include +#include #endif #include "cutlass/cutlass.h" @@ -40,6 +41,8 @@ namespace cutlass { +///////////////////////////////////////////////////////////////////////////////////////////////// + /****************************************************************************** * Static math utilities ******************************************************************************/ @@ -136,6 +139,19 @@ CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b) { return temp ? (a / temp * b) : 0; } +/// Returns the smallest value in the half-open range [a, a+b) that is a multiple of b +CUTLASS_HOST_DEVICE +constexpr int round_up(int a, int b) { + return ((a + b - 1) / b) * b; +} + +/// Returns the ceiling of (a / b) +CUTLASS_HOST_DEVICE +constexpr int ceil_div(int a, int b) { + return (a + b - 1) / b; +} + + /** * log2 computation, what's the * difference between the below codes and @@ -189,7 +205,6 @@ void fast_divmod(int& quo, int& rem, int src, int div, unsigned int mul, unsigne // The remainder. rem = src - (quo * div); - } // For long int input @@ -206,17 +221,56 @@ void fast_divmod(int& quo, int64_t& rem, int64_t src, int div, unsigned int mul, rem = src - (quo * div); } -/// Returns the smallest value in the half-open range [a, a+b) that is a multiple of b -CUTLASS_HOST_DEVICE -int round_up(int a, int b) { - return ((a + b - 1) / b) * b; -} +/// Object to encapsulate the fast division+modulus operation. +/// +/// This object precomputes two values used to accelerate the computation and is best used +/// when the divisor is a grid-invariant. 
In this case, it may be computed in host code and +/// marshalled along other kernel arguments using the 'Params' pattern. +/// +/// Example: +/// +/// +/// int quotient, remainder, dividend, divisor; +/// +/// FastDivmod divmod(divisor); +/// +/// divmod(quotient, remainder, dividend); +/// +/// // quotient = (dividend / divisor) +/// // remainder = (dividend % divisor) +/// +struct FastDivmod { + + int divisor; + unsigned int multiplier; + unsigned int shift_right; + + /// Construct the FastDivmod object, in host code ideally. + /// + /// This precomputes some values based on the divisor and is computationally expensive. + + CUTLASS_HOST_DEVICE + FastDivmod(): divisor(0), multiplier(0), shift_right(0) { } + + CUTLASS_HOST_DEVICE + FastDivmod(int divisor_): divisor(divisor_) { + find_divisor(multiplier, shift_right, divisor); + } -/// Returns the ceiling of (a / b) -CUTLASS_HOST_DEVICE -int ceil_div(int a, int b) { - return (a + b - 1) / b; -} + /// Computes integer division and modulus using precomputed values. This is computationally + /// inexpensive. + CUTLASS_HOST_DEVICE + void operator()(int &quotient, int &remainder, int dividend) const { + fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right); + } + + /// Computes integer division and modulus using precomputed values. This is computationally + /// inexpensive. + CUTLASS_HOST_DEVICE + void operator()(int &quotient, int64_t &remainder, int64_t dividend) const { + fast_divmod(quotient, remainder, dividend, divisor, multiplier, shift_right); + } +}; /****************************************************************************** * Min/Max ******************************************************************************/ @@ -242,4 +296,117 @@ constexpr int const_max(int a, int b) { return (b > a ? b : a); } +CUTLASS_HOST_DEVICE +float fast_cos(float theta) { + #if defined(__CUDA_ARCH__) + return ::cosf(theta); + #else + return std::cos(theta); + #endif +} + +CUTLASS_HOST_DEVICE +double fast_cos(double theta) { + #if defined(__CUDA_ARCH__) + return ::cos(theta); + #else + return std::cos(theta); + #endif +} + +CUTLASS_HOST_DEVICE +float fast_sin(float theta) { + #if defined(__CUDA_ARCH__) + return ::sinf(theta); + #else + return std::sin(theta); + #endif +} + +CUTLASS_HOST_DEVICE +double fast_sin(double theta) { + #if defined(__CUDA_ARCH__) + return ::sin(theta); + #else + return std::sin(theta); + #endif +} + +CUTLASS_HOST_DEVICE +float fast_acos(float theta) { + #if defined(__CUDA_ARCH__) + return ::acosf(theta); + #else + return std::acos(theta); + #endif +} + +CUTLASS_HOST_DEVICE +double fast_acos(double theta) { + #if defined(__CUDA_ARCH__) + return ::acos(theta); + #else + return std::acos(theta); + #endif +} + +CUTLASS_HOST_DEVICE +float fast_asin(float theta) { + #if defined(__CUDA_ARCH__) + return ::asinf(theta); + #else + return std::asin(theta); + #endif +} + +CUTLASS_HOST_DEVICE +double fast_asin(double theta) { + #if defined(__CUDA_ARCH__) + return ::asin(theta); + #else + return std::asin(theta); + #endif +} + +CUTLASS_HOST_DEVICE +float fast_sqrt(float theta) { + #if defined(__CUDA_ARCH__) + return ::sqrtf(theta); + #else + return std::sqrt(theta); + #endif +} + +CUTLASS_HOST_DEVICE +double fast_sqrt(double theta) { + #if defined(__CUDA_ARCH__) + return ::sqrt(theta); + #else + return std::sqrt(theta); + #endif +} + +CUTLASS_HOST_DEVICE +float fast_log(float x) { + #if defined(__CUDA_ARCH__) + return ::logf(x); + #else + return std::log(x); + #endif +} + +CUTLASS_HOST_DEVICE +double fast_log(double x) { + #if defined(__CUDA_ARCH__) + return ::log(x); + #else + return std::log(x); + #endif
+} + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 13ee7f542b..90cf394941 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -32,9 +32,7 @@ #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" - #include "cutlass/complex.h" - #include "cutlass/array.h" #include "cutlass/half.h" @@ -69,6 +67,82 @@ struct multiplies { } }; +/// Squares with optional conversion +template +struct square { + CUTLASS_HOST_DEVICE + Output operator()(T lhs) const { + multiplies mul_op; + + Output y = Output(lhs); + return mul_op(y, y); + } +}; + +/// Returns the magnitude squared of an element. +template +struct magnitude_squared { + CUTLASS_HOST_DEVICE + Output operator()(T lhs) const { + multiplies mul_op; + + Output y = Output(lhs); + return mul_op(y, y); + } +}; + +/// Squares with optional conversion +template +struct magnitude_squared, Output> { + CUTLASS_HOST_DEVICE + Output operator()(complex lhs) const { + multiplies mul_op; + + Output y_r = Output(lhs.real()); + Output y_i = Output(lhs.imag()); + + return mul_op(y_r, y_r) + mul_op(y_i, y_i); + } +}; + +/// Computes the square of a difference with optional conversion +template +struct square_difference { + CUTLASS_HOST_DEVICE + Output operator()(T lhs, T rhs) const { + multiplies mul_op; + + Output y = Output(lhs) - Output(rhs); + return mul_op(y, y); + } +}; + +/// Computes the square of a difference with optional conversion +template +struct magnitude_squared_difference { + CUTLASS_HOST_DEVICE + Output operator()(T lhs, T rhs) const { + multiplies mul_op; + + Output y = Output(lhs) - Output(rhs); + return mul_op(y, y); + } +}; + +/// Computes the square of a difference with optional conversion +template +struct magnitude_squared_difference, Output> { + CUTLASS_HOST_DEVICE + Output operator()(complex lhs, complex rhs) const { + multiplies mul_op; + + Output y_r = Output(lhs.real()) - Output(rhs.real()); + Output y_i = Output(lhs.imag()) - Output(rhs.imag()); + + return mul_op(y_r, y_r) + mul_op(y_i, y_i); + } +}; + template struct divides { CUTLASS_HOST_DEVICE diff --git a/include/cutlass/gemm/device/gemm_sparse.h b/include/cutlass/gemm/device/gemm_sparse.h new file mode 100644 index 0000000000..df2a141cd1 --- /dev/null +++ b/include/cutlass/gemm/device/gemm_sparse.h @@ -0,0 +1,517 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/sparse_gemm.h" + +#include "cutlass/gemm/kernel/default_gemm_sparse.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! Gemm device-level operator. This is an interface to efficient CUTLASS GEMM kernels that may + be invoked from host code. + + The contributions of this class are: + + 1. At compile time, it maps data types and high-level structural parameters onto + specific CUTLASS components. + + 2. At runtime, it maps logical arguments to GEMM problems to kernel parameters. + + 3. At runtime, it launches kernels on the device. + + The intent is to provide a convenient mechanism for interacting with most plausible GEMM + configurations for each supported architecture. Consequently, not all parameters are exposed + to the top-level interface. Rather, sensible defaults at each level of the CUTLASS hierarchy + are selected to tradeoff simplicity of the interface with flexibility. We expect + most configurations to be specified at this level. Applications with more exotic requirements + may construct their kernels of interest using CUTLASS components at the threadblock, warp, + and thread levels of abstraction. + + CUTLASS exposes computations using the functor design pattern in which objects compose some + internal state with an overloaded function call operator. This enables decoupling of + initialization from execution, possibly reducing overhead during steady state phases of + application execution. + + CUTLASS device-level operators expose an Arguments structure encompassing each logical + input to the computation. This is distinct from the kernel-level Params structure pattern + which contains application-specific precomputed state needed by the device code. + + Example of a CUTLASS GEMM operator implementing the functionality of cuBLAS's SGEMM NN + is as follows: + + // + // Instantiate the CUTLASS GEMM operator. 
+ // + + cutlass::gemm::device::Gemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor + > gemm_op; + + // + // Launch the GEMM operation on the device + // + + cutlass::Status status = gemm_op({ + {m, n, k}, // GemmCoord problem_size, + {A, lda}, // TensorRef ref_A, + {B, ldb}, // TensorRef ref_B, + {C, ldc}, // TensorRef ref_C, + {D, ldd}, // TensorRef ref_D, + {alpha, beta} // EpilogueOutputOp::Params epilogue_op_params + }); + + + A simplified view of the template is listed below. + + template < + /// Element type for A matrix operand + typename ElementA, + + /// Layout type for A matrix operand + typename LayoutA, + + /// Element type for B matrix operand + typename ElementB, + + /// Layout type for B matrix operand + typename LayoutB, + + /// Element type for C and D matrix operands + typename ElementC, + + /// Layout type for C and D matrix operands + typename LayoutC, + + /// Element type for internal accumulation + typename ElementAccumulator, + + /// Operator class tag + typename OperatorClass, + + /// Tag indicating architecture to tune for + typename ArchTag, + + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + + /// Epilogue output operator + typename EpilogueOutputOp, + + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + + /// Number of stages used in the pipelined mainloop + int Stages + > + class Gemm; +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = + typename threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + 
DefaultGemmConfiguration::kAlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Whether Beta is zero or not + bool IsBetaZero = false> +class SparseGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static bool const kIsBetaZero = IsBetaZero; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Define the kernel + using GemmKernel = typename kernel::DefaultSparseGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator, + kIsBetaZero + >::GemmKernel; + + using ElementE = typename GemmKernel::ElementE; + + using LayoutE = typename GemmKernel::LayoutE; + + static int const kAlignmentE = 128 / sizeof_bits::value; + + static int const kSparse = GemmKernel::kSparse; + static int const kMetaSizeInBits = GemmKernel::kMetaSizeInBits; + static int const kElementsPerElementE = GemmKernel::kElementsPerElementE; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + TensorRef ref_E; + typename EpilogueOutputOp::Params epilogue; + int split_k_slices; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0), split_k_slices(1) { + + } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_C_, + TensorRef ref_D_, + TensorRef ref_E_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params(), + int split_k_slices = 1 + ): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_C(ref_C_), + ref_D(ref_D_), + ref_E(ref_E_), + epilogue(epilogue_), + split_k_slices(split_k_slices) { + + } + }; + +private: + + /// Kernel parameters object + typename GemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + SparseGemm() { } + + /// Determines whether the GEMM can execute the given problem. 
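Before the launch-oriented methods that follow, it helps to see how the sparse-specific constants exposed above (kSparse, kElementsPerElementE, ElementE) size the operand tensors. The sketch below is illustrative only: the element types, layouts, and tile shapes are plausible choices for Sm80 sparse Tensor Cores (patterned after typical f16 configurations), not values mandated by this header, and the remaining template parameters fall back to DefaultGemmConfiguration as documented above.

  #include "cutlass/gemm/device/gemm_sparse.h"

  // Illustrative instantiation; SparseGemmOp is a name introduced for this sketch.
  using SparseGemmOp = cutlass::gemm::device::SparseGemm<
      cutlass::half_t, cutlass::layout::RowMajor,        // A
      cutlass::half_t, cutlass::layout::ColumnMajor,     // B
      cutlass::half_t, cutlass::layout::RowMajor,        // C and D
      float,                                             // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 64>,            // threadblock tile (illustrative)
      cutlass::gemm::GemmShape<64, 64, 64>,              // warp tile        (illustrative)
      cutlass::gemm::GemmShape<16, 8, 32>>;              // sparse mma shape (illustrative)
                                                         // remaining parameters use the defaults

  // A stores only the kept portion of the K dimension (k / kSparse columns; for
  // 16-bit data kSparse is typically 2, i.e. 2:4 structured sparsity). E packs
  // kElementsPerElementE metadata entries of kMetaSizeInBits bits into each
  // ElementE word, so its column extent shrinks by that factor as well.
  int m = 512, n = 512, k = 1024;

  cutlass::MatrixCoord extent_A(m, k / SparseGemmOp::kSparse);
  cutlass::MatrixCoord extent_B(k, n);
  cutlass::MatrixCoord extent_C(m, n);
  cutlass::MatrixCoord extent_E(
      m, k / SparseGemmOp::kSparse / SparseGemmOp::kElementsPerElementE);

can_implement(), shown next, performs the corresponding alignment and split-K checks on arguments built from tensors of these extents.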
+ static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + + Status status = GemmKernel::can_implement( + args.problem_size, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + args.ref_E.non_const_ref() + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial && args.split_k_slices > 1) { + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + args.ref_E.non_const_ref(), + args.epilogue, + static_cast(workspace) + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A.reset(args.ref_A.non_const_ref().data()); + params_.ref_B.reset(args.ref_B.non_const_ref().data()); + params_.ref_C.reset(args.ref_C.non_const_ref().data()); + params_.ref_D.reset(args.ref_D.data()); + params_.ref_E.reset(args.ref_E.non_const_ref().data()); + params_.output_op = args.epilogue; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
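Taken together, Arguments, can_implement(), get_workspace_size(), initialize(), and update() give the usual CUTLASS host-side flow. A minimal sketch follows, continuing the illustrative SparseGemmOp alias introduced above; the function name and the idea of taking TensorRef parameters are assumptions made for the example, not part of this header.

  // Hypothetical helper wrapping one sparse GEMM launch.
  cutlass::Status run_sparse_gemm(
      cutlass::gemm::GemmCoord problem_size,
      SparseGemmOp::TensorRefA ref_A,                      // compressed A: m x (k / kSparse)
      SparseGemmOp::TensorRefB ref_B,                      // B: k x n
      SparseGemmOp::TensorRefC ref_C,                      // C source
      SparseGemmOp::TensorRefD ref_D,                      // D destination
      cutlass::TensorRef<SparseGemmOp::ElementE const,
                         SparseGemmOp::LayoutE> ref_E,     // reordered metadata
      float alpha, float beta) {

    SparseGemmOp gemm_op;

    // split_k_slices defaults to 1, matching kSplitKSerial == false.
    SparseGemmOp::Arguments args(
        problem_size, ref_A, ref_B, ref_C, ref_D, ref_E, {alpha, beta});

    cutlass::Status status = SparseGemmOp::can_implement(args);
    if (status != cutlass::Status::kSuccess) {
      return status;
    }

    // With a single split-K slice the workspace size is zero; a serial split-K
    // configuration would need this allocation for its semaphores.
    void *workspace = nullptr;
    size_t workspace_bytes = SparseGemmOp::get_workspace_size(args);
    if (workspace_bytes) {
      cudaMalloc(&workspace, workspace_bytes);
    }

    status = gemm_op.initialize(args, workspace);
    if (status == cutlass::Status::kSuccess) {
      status = gemm_op.run();    // equivalent to gemm_op()
    }

    if (workspace) {
      cudaFree(workspace);
    }
    return status;
  }

run(), defined immediately below, performs the actual launch, including raising the dynamic shared memory limit when the kernel needs more than 48 KB.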
+ Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + result = cudaFuncSetAttribute( + Kernel, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h index 18ccb3469a..fc52a08d0f 100644 --- a/include/cutlass/gemm/device/gemm_universal_base.h +++ b/include/cutlass/gemm/device/gemm_universal_base.h @@ -30,6 +30,8 @@ #pragma once +#include + #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" #include "cutlass/arch/arch.h" @@ -42,6 +44,8 @@ #include "cutlass/gemm/kernel/default_gemm_universal.h" #include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/trace.h" + //////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -121,13 +125,30 @@ class GemmUniversalBase { /// Determines whether the GEMM can execute the given problem. 
static Status can_implement(Arguments const &args) { + + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + ThreadblockSwizzle threadblock_swizzle; + dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + if (!(grid.y <= std::numeric_limits::max() && + grid.z <= std::numeric_limits::max())) { + + return Status::kErrorInvalidProblem; + } return GemmKernel::can_implement(args); } /// Gets the workspace size static size_t get_workspace_size(Arguments const &args) { - + + CUTLASS_TRACE_HOST("GemmUniversalBase::get_workspace_size()"); + size_t workspace_bytes = 0; // Determine grid shape @@ -151,28 +172,41 @@ class GemmUniversalBase { workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); } + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + return workspace_bytes; } /// Computes the grid shape static dim3 get_grid_shape(Arguments const &args) { + CUTLASS_TRACE_HOST("GemmUniversalBase::get_grid_shape()"); + ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord grid_tiled_shape; int gemm_k_size = 0; get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + CUTLASS_TRACE_HOST( + " grid_tiled_shape: " << grid_tiled_shape << "\n" + << " result = {" << result << "}"); - return threadblock_swizzle.get_grid_shape(grid_tiled_shape); + return result; } /// Computes the maximum number of active blocks per multiprocessor static int maximum_active_blocks(int smem_capacity = -1) { + CUTLASS_TRACE_HOST("GemmUniversalBase::maximum_active_blocks()"); + int max_active_blocks = -1; int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + if (smem_size <= (48 << 10)) { cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( @@ -182,6 +216,7 @@ class GemmUniversalBase { smem_size); if (result == cudaSuccess) { + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); return max_active_blocks; } } @@ -195,6 +230,11 @@ class GemmUniversalBase { 0); if (result != cudaSuccess) { + + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " + << cudaGetErrorString(result)); + return -1; } @@ -216,27 +256,43 @@ class GemmUniversalBase { smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); } - return std::min(max_active_blocks, smem_capacity / smem_size); + int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); + + CUTLASS_TRACE_HOST(" occupancy: " << occupancy); + + return occupancy; } + CUTLASS_TRACE_HOST(" returning internal error"); + return -1; } /// Initializes GEMM state from arguments. Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBase::initialize() - workspace " + << workspace << ", stream: " << (stream ? 
"non-null" : "null")); + size_t workspace_bytes = get_workspace_size(args); + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + if (workspace_bytes) { if (!workspace) { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + return Status::kErrorWorkspaceNull; } if (args.mode == GemmUniversalMode::kGemm) { + CUTLASS_TRACE_HOST(" clearing device workspace"); cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + return Status::kErrorInternal; } } @@ -262,6 +318,8 @@ class GemmUniversalBase { /// Lightweight update given a subset of arguments Status update(Arguments const &args, void *workspace = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBase()::update() - workspace: " << workspace); + size_t workspace_bytes = get_workspace_size(args); if (workspace_bytes && !workspace) { @@ -275,6 +333,7 @@ class GemmUniversalBase { /// Runs the kernel using initialized state. Status run(cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBase::run()"); ThreadblockSwizzle threadblock_swizzle; @@ -302,11 +361,19 @@ class GemmUniversalBase { } } + CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block + << "), SMEM: " << smem_size << " bytes"); + cutlass::Kernel<<>>(params_); result = cudaGetLastError(); - return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + + return Status::kSuccess; } /// Runs the kernel using initialized state. diff --git a/include/cutlass/gemm/gemm.h b/include/cutlass/gemm/gemm.h index 78d0a6da6f..51f535f7c1 100644 --- a/include/cutlass/gemm/gemm.h +++ b/include/cutlass/gemm/gemm.h @@ -96,7 +96,7 @@ struct GemmCoord : public Coord<3, int> { /// Integer-valued index typedef int Index; - /// Base type is a Coord of rank=4 + /// Base type is a Coord of rank=3 typedef Coord<3, Index> Base; /// GEMM M dimension - rows of the output C matrix @@ -274,7 +274,7 @@ struct BatchedGemmCoord : public Coord<4, int> { /// GEMM K dimension - inner dimension of the GEMM problem static int const kK = 2; - /// GEMM K dimension - inner dimension of the GEMM problem + /// GEMM Batch dimension - inner dimension of the GEMM problem static int const kBatch = 3; // diff --git a/include/cutlass/gemm/kernel/default_gemm_sparse.h b/include/cutlass/gemm/kernel/default_gemm_sparse.h new file mode 100644 index 0000000000..9c43666fe0 --- /dev/null +++ b/include/cutlass/gemm/kernel/default_gemm_sparse.h @@ -0,0 +1,187 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm.h" +#include "cutlass/gemm/kernel/sparse_gemm.h" +#include "cutlass/gemm/kernel/gemm_pipelined.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h" +#include "cutlass/gemm/threadblock/default_sparse_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) +#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" +#endif //CUTLASS_ARCH_WMMA_ENABLED + + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D 
matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Beta is zero or not + bool IsBetaZero = false> +struct DefaultSparseGemm; + +//////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ampere Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator> +struct DefaultSparseGemm { + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultSparseMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, Stages, + Operator>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount>::Epilogue; + + /// Define the kernel-level GEMM operator. 
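For reference, instantiating this partial specialization directly yields the kernel type that the device-level SparseGemm wrapper consumes. The concrete element types, alignments, and tile shapes below are illustrative choices for Sm80, not requirements of this header; any combination supported by DefaultSparseMma works. Note also that with a 128x128x64 threadblock tile and a 64x64x64 warp tile, kPartitionsK above evaluates to 1.

  // Illustrative kernel-level instantiation (names of the concrete parameters
  // are example choices, not defaults of DefaultSparseGemm).
  using SparseKernel = cutlass::gemm::kernel::DefaultSparseGemm<
      cutlass::half_t, cutlass::layout::RowMajor, 8,       // A and its alignment
      cutlass::half_t, cutlass::layout::ColumnMajor, 8,    // B and its alignment
      cutlass::half_t, cutlass::layout::RowMajor,          // C and D
      float,                                               // accumulator
      cutlass::arch::OpClassTensorOp,
      cutlass::arch::Sm80,
      cutlass::gemm::GemmShape<128, 128, 64>,              // threadblock tile
      cutlass::gemm::GemmShape<64, 64, 64>,                // warp tile
      cutlass::gemm::GemmShape<16, 8, 32>,                 // sparse mma instruction
      cutlass::epilogue::thread::LinearCombination<
          cutlass::half_t, 8, float, float>,
      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
      3,                                                   // stages
      false,                                               // SplitKSerial
      cutlass::arch::OpMultiplyAdd>::GemmKernel;

The using declaration that follows wires exactly this composition of Mma and Epilogue into kernel::SparseGemm.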
+ using GemmKernel = kernel::SparseGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + diff --git a/include/cutlass/gemm/kernel/gemm.h b/include/cutlass/gemm/kernel/gemm.h index 6700659a1f..c3aa6f8f7a 100644 --- a/include/cutlass/gemm/kernel/gemm.h +++ b/include/cutlass/gemm/kernel/gemm.h @@ -175,7 +175,8 @@ struct Gemm { // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || @@ -252,7 +253,8 @@ struct Gemm { // Masked tile iterators constructed from members // - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); //assume identity swizzle MatrixCoord threadblock_offset( diff --git a/include/cutlass/gemm/kernel/gemm_array.h b/include/cutlass/gemm/kernel/gemm_array.h index f63571b023..8cf25fb7df 100644 --- a/include/cutlass/gemm/kernel/gemm_array.h +++ b/include/cutlass/gemm/kernel/gemm_array.h @@ -133,7 +133,8 @@ struct GemmArray { // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || @@ -207,7 +208,8 @@ struct GemmArray { // Masked tile iterators constructed from members // - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); //assume identity swizzle MatrixCoord threadblock_offset( diff --git a/include/cutlass/gemm/kernel/gemm_batched.h b/include/cutlass/gemm/kernel/gemm_batched.h index eb638375c0..ac8f5a3799 100644 --- a/include/cutlass/gemm/kernel/gemm_batched.h +++ b/include/cutlass/gemm/kernel/gemm_batched.h @@ -140,7 +140,8 @@ struct GemmBatched { // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || @@ -219,7 +220,8 @@ struct GemmBatched { // Masked tile iterators constructed from members // - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); //assume identity swizzle MatrixCoord threadblock_offset( diff --git a/include/cutlass/gemm/kernel/gemm_pipelined.h b/include/cutlass/gemm/kernel/gemm_pipelined.h index 6caa0eae31..02c7ba254b 100644 --- a/include/cutlass/gemm/kernel/gemm_pipelined.h +++ b/include/cutlass/gemm/kernel/gemm_pipelined.h @@ -66,7 +66,7 @@ __global__ void GemmPipelined( // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; - cutlass::gemm::GemmCoord tb_tile_offset = threadblock_swizzle.get_tile_offset(); + cutlass::gemm::GemmCoord tb_tile_offset = 
threadblock_swizzle.get_tile_offset(grid_tiled_shape); if (grid_tiled_shape.m() <= tb_tile_offset.m() || grid_tiled_shape.n() <= tb_tile_offset.n()) { @@ -131,7 +131,7 @@ __global__ void GemmPipelined( warp_id, lane_id); - tb_tile_offset = threadblock_swizzle.get_tile_offset(); + tb_tile_offset = threadblock_swizzle.get_tile_offset(grid_tiled_shape); //assume identity swizzle MatrixCoord threadblock_offset( diff --git a/include/cutlass/gemm/kernel/gemm_planar_complex.h b/include/cutlass/gemm/kernel/gemm_planar_complex.h index e05112569b..ab888940f2 100644 --- a/include/cutlass/gemm/kernel/gemm_planar_complex.h +++ b/include/cutlass/gemm/kernel/gemm_planar_complex.h @@ -419,7 +419,8 @@ struct GemmPlanarComplex { // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || @@ -549,7 +550,8 @@ struct GemmPlanarComplex { // Masked tile iterators constructed from members // - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); //assume identity swizzle MatrixCoord threadblock_offset( diff --git a/include/cutlass/gemm/kernel/gemm_planar_complex_array.h b/include/cutlass/gemm/kernel/gemm_planar_complex_array.h index 00841d4692..0023bd583b 100644 --- a/include/cutlass/gemm/kernel/gemm_planar_complex_array.h +++ b/include/cutlass/gemm/kernel/gemm_planar_complex_array.h @@ -376,7 +376,8 @@ struct GemmPlanarComplexArray { // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || diff --git a/include/cutlass/gemm/kernel/gemm_splitk_parallel.h b/include/cutlass/gemm/kernel/gemm_splitk_parallel.h index 973897521f..72ca5a4743 100644 --- a/include/cutlass/gemm/kernel/gemm_splitk_parallel.h +++ b/include/cutlass/gemm/kernel/gemm_splitk_parallel.h @@ -128,7 +128,8 @@ struct GemmSplitKParallel { // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || @@ -205,7 +206,8 @@ struct GemmSplitKParallel { // Masked tile iterators constructed from members // - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); //assume identity swizzle MatrixCoord threadblock_offset( diff --git a/include/cutlass/gemm/kernel/gemm_universal.h b/include/cutlass/gemm/kernel/gemm_universal.h index 6efd50a7fd..e6e3c97bed 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.h +++ b/include/cutlass/gemm/kernel/gemm_universal.h @@ -36,6 +36,8 @@ #include "cutlass/complex.h" #include "cutlass/semaphore.h" +#include "cutlass/trace.h" + 
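The cutlass/trace.h header included here (and in gemm_universal_base.h above) backs the CUTLASS_TRACE_HOST calls added throughout the device and kernel layers; by default they compile to nothing. If the build defines a positive trace level, the host-side messages shown in this patch (workspace sizes, grid shapes, occupancy, launch errors) are printed as the operator runs. A minimal way to turn this on, assuming the CUTLASS_DEBUG_TRACE_LEVEL guard used by cutlass/trace.h and a placeholder include path:

  nvcc -DCUTLASS_DEBUG_TRACE_LEVEL=1 -I<path-to-cutlass>/include my_gemm.cu -o my_gemm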
///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -154,6 +156,7 @@ struct GemmUniversal { batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), batch_stride_D(batch_stride_D), lda(lda), ldb(ldb), ldc(ldc), ldd(ldd) { + CUTLASS_TRACE_HOST("GemmUniversal::Arguments::Arguments() - problem_size: " << problem_size); } /// Returns arguments for the transposed problem @@ -252,6 +255,7 @@ struct GemmUniversal { batch_stride_D(args.batch_stride_D), semaphore(static_cast(workspace)) { + CUTLASS_TRACE_HOST("GemmUniversal::Params::Params() - problem_size: " << problem_size); } CUTLASS_HOST_DEVICE @@ -264,9 +268,16 @@ struct GemmUniversal { ptr_C = const_cast(args.ptr_C); ptr_D = args.ptr_D; + batch_stride_A = args.batch_stride_A; + batch_stride_B = args.batch_stride_B; + batch_stride_C = args.batch_stride_C; + batch_stride_D = args.batch_stride_D; + output_op = args.epilogue; semaphore = static_cast(workspace); + + CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); } }; @@ -289,6 +300,8 @@ struct GemmUniversal { static Status can_implement( cutlass::gemm::GemmCoord const & problem_size) { + CUTLASS_TRACE_HOST("GemmUniversal::can_implement()"); + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; @@ -297,9 +310,12 @@ struct GemmUniversal { (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC)) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand"); return Status::kErrorMisalignedOperand; } + CUTLASS_TRACE_HOST(" returning kSuccess"); + return Status::kSuccess; } @@ -314,7 +330,8 @@ struct GemmUniversal { // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || @@ -421,7 +438,8 @@ struct GemmUniversal { // Masked tile iterators constructed from members // - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(); + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); //assume identity swizzle MatrixCoord threadblock_offset( diff --git a/include/cutlass/gemm/kernel/sparse_gemm.h b/include/cutlass/gemm/kernel/sparse_gemm.h new file mode 100644 index 0000000000..85e3839ce7 --- /dev/null +++ b/include/cutlass/gemm/kernel/sparse_gemm.h @@ -0,0 +1,392 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. 
+> +struct SparseGemm { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using OutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + static int const kSparse = Mma::kSparse; + static int const kMetaSizeInBits = Mma::kMetaSizeInBits; + static int const kMaxID2 = Mma::kMaxID2; + static int const kElementsPerElementE = Mma::kElementsPerElementE; + + using ElementE = typename Mma::ElementE; + using LayoutE = typename Mma::LayoutE; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename Mma::IteratorE::Params params_E; + typename Mma::IteratorE::TensorRef ref_E; + typename OutputOp::Params output_op; + int *semaphore; + int gemm_k_iterations; + int gemm_k_size; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): semaphore(0), gemm_k_iterations(0), gemm_k_size(0) { } + + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, + typename Mma::IteratorE::TensorRef ref_E, + typename OutputOp::Params output_op = typename OutputOp::Params(), + int *workspace = nullptr + ): + problem_size(problem_size), + grid_tiled_shape(grid_tiled_shape), + params_A(ref_A.layout()), + ref_A(ref_A), + params_B(ref_B.layout()), + ref_B(ref_B), + params_C(ref_C.layout()), + ref_C(ref_C), + params_D(ref_D.layout()), + ref_D(ref_D), + params_E(ref_E.layout()), + ref_E(ref_E), + output_op(output_op) { + + int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + + gemm_k_size = gemm_k_iterations * Mma::Shape::kK; + + semaphore = workspace; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + SparseGemm() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, + typename Mma::IteratorE::TensorRef ref_E) { + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + static int const kAlignmentE = Mma::IteratorE::AccessType::kElements; + + if (!TensorRef_aligned(ref_A, kAlignmentA)) 
{ + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_E, kAlignmentE)) { + return Status::kErrorMisalignedOperand; + } + + if ((problem_size.m() % kAlignmentA) || ((problem_size.k() / kSparse) % kAlignmentA) || + (problem_size.n() % kAlignmentB) || (problem_size.k() % kAlignmentB) || + (problem_size.m() % kAlignmentC) || (problem_size.n() % kAlignmentC) || + (problem_size.m() % kAlignmentE) || ((problem_size.k() / kSparse) % kAlignmentE)) { + + return Status::kErrorMisalignedOperand; + } + + // The k dimension has to be the multiple of the Threadblock k because out + // of bound meta data would be initialized to 0 by acync.zfill but 0 is not + // a valid meta data. + if (problem_size.k() % Mma::Shape::kK) { + return Status::kErrorMisalignedOperand; + } + + // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) + // because of the row reordering of operand E + static int const kAlignmentM = (sizeof(ElementE) == 2) ? 32 : 16; + + if (problem_size.m() % kAlignmentM) { + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size / kSparse, + }; + + cutlass::MatrixCoord tb_offset_B{ + threadblock_tile_offset.k() * params.gemm_k_size, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + cutlass::MatrixCoord tb_offset_E{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size / kSparse, + }; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min( + params.problem_size.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_B.row() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A, B, and E operands + typename Mma::IteratorA iterator_A( + params.params_A, + params.ref_A.data(), + {params.problem_size.m(), problem_size_k / kSparse}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + params.ref_B.data(), + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + typename Mma::IteratorE iterator_E( + params.params_E, params.ref_E.data(), + {params.problem_size.m(), + problem_size_k / kSparse / kElementsPerElementE}, + thread_idx, tb_offset_E); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
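  // Worked instance of the k-slice arithmetic above (illustrative numbers only):
  // with problem_size.k() = 1024, Mma::Shape::kK = 64, and grid_tiled_shape.k() = 2,
  // Params computes gemm_k_size = 512. The threadblock at k-index 1 therefore sees
  // problem_size_k = min(1024, 2 * 512) = 1024, starts B at row 512 (and the
  // compressed A / metadata E at column 512 / kSparse = 256 when kSparse == 2),
  // and runs gemm_k_iterations = (1024 - 512 + 63) / 64 = 8 mainloop iterations.
  // The broadcast below then makes the warp index warp-uniform before the mainloop.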
+ int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_E, accumulators); + } + + // + // Epilogue + // + + OutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + params.ref_C.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + params.ref_D.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + + __threadfence(); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. 
+ lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + __threadfence(); + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/include/cutlass/gemm/thread/mma_sm50.h b/include/cutlass/gemm/thread/mma_sm50.h index 04658f7bc0..6d52efb023 100644 --- a/include/cutlass/gemm/thread/mma_sm50.h +++ b/include/cutlass/gemm/thread/mma_sm50.h @@ -229,6 +229,16 @@ struct Mma< /// C operand storage using FragmentC = Array; + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename MmaGeneric< + Shape, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + Operator>::MmaOp; // // Methods // diff --git a/include/cutlass/gemm/thread/mma_sm60.h b/include/cutlass/gemm/thread/mma_sm60.h index 16d0d61c24..486497cb79 100644 --- a/include/cutlass/gemm/thread/mma_sm60.h +++ b/include/cutlass/gemm/thread/mma_sm60.h @@ -977,6 +977,30 @@ struct Mma< /// C operand storage using FragmentC = Array; + static bool const a_row_major = platform::is_same< LayoutA, layout::RowMajor>::value; + static bool const b_column_major = platform::is_same< LayoutB, layout::ColumnMajor>::value; + static bool const c_row_major = platform::is_same< LayoutC, layout::RowMajor>::value; + static bool const c_column_major = platform::is_same< LayoutC, layout::ColumnMajor>::value; + + static bool const m_mod2 = !(Shape::kM % 2); + static bool const n_mod2 = !(Shape::kN % 2); + static bool const k_mod2 = !(Shape::kK % 2); + + // HFMA based MMA optimizations are of 2 types : + // 1. Inner product + // 2. Outer product + // It is chosen based on LayoutC (for outer product gemm) or + // Using LayoutA and LayoutB or shape=1x1x2K (for inner product gemms) + // If all fails, we choose the generic MMA + static bool const use_outer_prod = (c_column_major && m_mod2) || (c_row_major && n_mod2); + static bool const use_inner_prod = (a_row_major && b_column_major && k_mod2) || (Shape::kM==1 && Shape::kN==1 && k_mod2); + static bool const use_optimized = (use_outer_prod || use_inner_prod); + + using ArchMmaOperator = typename platform::conditional< use_optimized, + detail::Mma_HFMA2, + MmaGeneric + >::type; + // // Methods // @@ -989,30 +1013,8 @@ struct Mma< FragmentB const & B, FragmentC const & C) { - constexpr bool a_row_major = platform::is_same< LayoutA, layout::RowMajor>::value; - constexpr bool b_column_major = platform::is_same< LayoutB, layout::ColumnMajor>::value; - constexpr bool c_row_major = platform::is_same< LayoutC, layout::RowMajor>::value; - constexpr bool c_column_major = platform::is_same< LayoutC, layout::ColumnMajor>::value; - - constexpr bool m_mod2 = !(Shape::kM % 2); - constexpr bool n_mod2 = !(Shape::kN % 2); - constexpr bool k_mod2 = !(Shape::kK % 2); - - // HFMA based MMA optimizations are of 2 types : - // 1. Inner product - // 2. 
Outer product - // It is chosen based on LayoutC (for outer product gemm) or - // Using LayoutA and LayoutB or shape=1x1x2K (for inner product gemms) - // If all fails, we choose the generic MMA - constexpr bool use_outer_prod = (c_column_major && m_mod2) || (c_row_major && n_mod2); - constexpr bool use_inner_prod = (a_row_major && b_column_major && k_mod2) || (Shape::kM==1 && Shape::kN==1 && k_mod2); - constexpr bool use_optimized = (use_outer_prod || use_inner_prod); - - typename platform::conditional< use_optimized, - detail::Mma_HFMA2, - MmaGeneric - >::type mma; - + ArchMmaOperator mma; + mma(D, A, B, C); } @@ -1086,6 +1088,8 @@ struct Mma< using FragmentB = Array; using FragmentC = Array; + using ArchMmaOperator = typename TransposeMma::ArchMmaOperator; + CUTLASS_HOST_DEVICE void operator()( FragmentC & D, diff --git a/include/cutlass/gemm/thread/mma_sm61.h b/include/cutlass/gemm/thread/mma_sm61.h index 83e31b2377..09fd356236 100644 --- a/include/cutlass/gemm/thread/mma_sm61.h +++ b/include/cutlass/gemm/thread/mma_sm61.h @@ -93,6 +93,19 @@ struct Mma< /// C operand storage using FragmentC = Array; + /// Underlying matrix multiply operator (concept: arch::Mma) + // Use 1x1x4 IDP4A sequence for bulk of computation + using ArchMmaOperator = arch::Mma< + gemm::GemmShape<1,1,4>, + 1, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + arch::OpMultiplyAdd>; + // // Methods // @@ -112,22 +125,11 @@ struct Mma< D = C; /// Use 1x1x4 IDP4A sequence for bulk of computation - using Mma = arch::Mma< - gemm::GemmShape<1,1,4>, - 1, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - arch::OpMultiplyAdd>; - - Mma mma; + ArchMmaOperator mma; // Compute matrix product CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Shape::kK / Mma::Shape::kK; ++k) { + for (int k = 0; k < Shape::kK / ArchMmaOperator::Shape::kK; ++k) { CUTLASS_PRAGMA_UNROLL for (int n = 0; n < Shape::kN; ++n) { @@ -143,8 +145,8 @@ struct Mma< mma( tmp, - ptr_A[m * Shape::kK / Mma::Shape::kK + k], - ptr_B[n * Shape::kK / Mma::Shape::kK + k], + ptr_A[m * Shape::kK / ArchMmaOperator::Shape::kK + k], + ptr_B[n * Shape::kK / ArchMmaOperator::Shape::kK + k], tmp); d.at(mn) = reinterpret_cast(tmp); @@ -206,6 +208,19 @@ struct Mma< /// C operand storage using FragmentC = Array; + /// Underlying matrix multiply operator (concept: arch::Mma) + /// Use 1x1x4 IDP4A sequence for bulk of computation + using ArchMmaOperator = arch::Mma< + gemm::GemmShape<1,1,4>, + 1, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + arch::OpMultiplyAdd>; + // // Methods // @@ -224,25 +239,15 @@ struct Mma< // Copy accumulators D = C; - /// Use 1x1x4 IDP4A sequence for bulk of computation - using Mma = arch::Mma< - gemm::GemmShape<1,1,4>, - 1, - ElementA, - LayoutA, - ElementB, - LayoutB, - ElementC, - LayoutC, - arch::OpMultiplyAdd>; - - Mma mma; + /// Underlying matrix multiply operator + ArchMmaOperator mma; + Array const *ptr_A = reinterpret_cast const *>(&A); Array const *ptr_B = reinterpret_cast const *>(&B); // Compute matrix product CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < Shape::kK / Mma::Shape::kK; ++k) { + for (int k = 0; k < Shape::kK / ArchMmaOperator::Shape::kK; ++k) { CUTLASS_PRAGMA_UNROLL for (int n = 0; n < Shape::kN; ++n) { diff --git a/include/cutlass/gemm/threadblock/default_mma.h b/include/cutlass/gemm/threadblock/default_mma.h index 3ebe14e6b8..fbf76510db 100644 --- a/include/cutlass/gemm/threadblock/default_mma.h +++ b/include/cutlass/gemm/threadblock/default_mma.h @@ -36,6 +36,7 @@ 
#include "cutlass/layout/matrix.h" #include "cutlass/transform/threadblock/predicated_tile_iterator.h" #include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" #include "cutlass/gemm/threadblock/default_mma_core_sm70.h" #include "cutlass/gemm/threadblock/default_mma_core_sm75.h" #include "cutlass/gemm/threadblock/default_mma_core_sm80.h" diff --git a/include/cutlass/gemm/threadblock/default_mma_core_sm50.h b/include/cutlass/gemm/threadblock/default_mma_core_sm50.h deleted file mode 100644 index 782cd7aea8..0000000000 --- a/include/cutlass/gemm/threadblock/default_mma_core_sm50.h +++ /dev/null @@ -1,197 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data - layout of the global memory fragments, data types, and internal tile sizes. - - Partial specializations for threadblock::Mma operations targeting TensorOp instructions. 
-*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/array.h" - -#include "cutlass/numeric_types.h" -#include "cutlass/matrix_shape.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/transform/pitch_linear_thread_map.h" -#include "cutlass/transform/threadblock/regular_tile_iterator.h" - -#include "cutlass/gemm/warp/mma_simt.h" -#include "cutlass/gemm/threadblock/default_mma_core.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization: -/// -/// A: column-major -/// B: row-major -/// InstructionShape: 1-by-1-by-1 -/// Operator: SIMT -/// -/// This uses the default warp-level operator given tile sizes -template < - /// Shape of threadblock-scoped matrix multiply operator (concept: - /// GemmShape) - typename Shape_, - /// Shape of warp-level matrix multiply operator (concept: GemmShape) - typename WarpShape_, - /// Data type of A operand - typename ElementA_, - /// Data type of B operand - typename ElementB_, - /// Data type of accumulator - typename ElementC_, - /// Layout of accumulator - typename LayoutC_, - /// Operation performed by GEMM - typename Operator_> -struct DefaultMmaCore, ElementA_, - layout::ColumnMajor, ElementB_, layout::RowMajor, - ElementC_, LayoutC_, arch::OpClassSimt, 2, Operator_, - > { - using Shape = Shape_; - using WarpShape = WarpShape_; - using InstructionShape = InstructionShape_; - using ElementA = ElementA_; - using LayoutA = layout::ColumnMajor; - using ElementB = ElementB_; - using LayoutB = layout::RowMajor; - using ElementC = ElementC_; - using LayoutC = LayoutC_; - using OperatorClass = arch::OpClassSimt; - - /// Number of warps present - using WarpCount = GemmShape< - Shape::kM / WarpShape::kM, - Shape::kN / WarpShape::kN, - Shape::kK / WarpShape::kK - >; - - // Divisility requirements - static_assert( - !(Shape::kM % WarpShape::kM) && - !(Shape::kN % WarpShape::kN), - "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." 
- ); - - /// Number of threads per warp - static int const kWarpSize = warp::WarpSize::value; - - /// Number of threads total - static int const kThreads = WarpCount::kCount * kWarpSize; - - // - // Shared memory layouts - // - - /// Shared memory layout for A operand - using SmemLayoutA = layout::ColumnMajor; - - /// Shared memory layout for B operand - using SmemLayoutB = layout::RowMajor; - - // - // Iterators to write to shared memory - // - - /// ThreadMap of iterator A - using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - 1 - >; - - /// Shared memory iterator to A operand - using SmemIteratorA = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementA, - SmemLayoutA, - 1, - IteratorThreadMapA - >; - - /// ThreadMap of iterator B - using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< - layout::PitchLinearShape, - kThreads, - 1 - >; - - /// Shared memory iterator to B operand - using SmemIteratorB = transform::threadblock::RegularTileIterator< - MatrixShape, - ElementB, - SmemLayoutB, - 0, - IteratorThreadMapB - >; - - // - // Warp-level matrix multiply operator - // - - // Define the warp-level tensor op - using WarpMma = cutlass::gemm::warp::MmaSimt< - WarpShape, - ElementA, - SmemLayoutA, - ElementB, - SmemLayoutB, - ElementC, - LayoutC, - warp::MmaSimtPolicy< - MatrixShape<4, 8>, - layout::RowMajorInterleaved<2>, - GemmShape< - 128 / sizeof_bits::value, - 128 / sizeof_bits::value, - 1> - > - > - >; - - /// Policy used to define MmaPipelined - using MmaPolicy = MmaPolicy< - WarpMma, - MatrixShape<0, 0>, - MatrixShape<0, 0>, - WarpCount::kK - >; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/include/cutlass/gemm/threadblock/default_mma_core_sm75.h b/include/cutlass/gemm/threadblock/default_mma_core_sm75.h index e7a2adcb14..d797704e79 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_sm75.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_sm75.h @@ -1119,11 +1119,18 @@ struct DefaultMmaCore(m, n) is mapped to Column/RowMajor(m +/// x InterleavedK, n / InterleavedK) so that Column/RowMajor global iterators +/// can be reused. The shared store iterator is the same as the crosswise shared +/// store iterator. So, the only thing we need to do is to swap the coordinates +/// (contiguous <=> strided) used by the global iterator and the shared store +/// iterator. template < /// Shape of threadblock-scoped matrix multiply operator (concept: /// GemmShape) diff --git a/include/cutlass/gemm/threadblock/default_mma_core_sm80.h b/include/cutlass/gemm/threadblock/default_mma_core_sm80.h index d9b3d9a0c5..065ed74694 100644 --- a/include/cutlass/gemm/threadblock/default_mma_core_sm80.h +++ b/include/cutlass/gemm/threadblock/default_mma_core_sm80.h @@ -1362,6 +1362,13 @@ struct DefaultMmaCore(m, n) is mapped to Column/RowMajor(m +/// x InterleavedK, n / InterleavedK) so that Column/RowMajor global iterators +/// can be reused. The shared store iterator is the same as the crosswise shared +/// store iterator. So, the only thing we need to do is to swap the coordinates +/// (contiguous <=> strided) used by the global iterator and the shared store +/// iterator. 
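As a concrete instance of the mapping described above: with InterleavedK = 32, a ColumnMajorInterleaved<32> operand of extent 256 x 128 is presented to the global iterator as a plain ColumnMajor matrix of extent (256 * 32) x (128 / 32) = 8192 x 4, and the contiguous and strided coordinates handed to the shared-memory store iterator are swapped relative to that view. The numbers are illustrative; the mapping itself is what the partial specializations below rely on.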
template < /// Shape of threadblock-scoped matrix multiply operator (concept: /// GemmShape) @@ -1608,7 +1615,7 @@ struct DefaultMmaCore; - /// Transpose the ThreadMap of iterator A + /// Transpose the ThreadMap of iterator B using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; /// Shared memory iterator to B operand @@ -1916,7 +1923,7 @@ struct DefaultMmaCore; - /// Transpose the ThreadMap of iterator A + /// Transpose the ThreadMap of iterator B using SmemThreadMapB = transform::TransposePitchLinearThreadMapSimt; /// Shared memory iterator to B operand diff --git a/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h b/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h new file mode 100644 index 0000000000..f7298e4e7e --- /dev/null +++ b/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h @@ -0,0 +1,828 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Defines basic properties needed by CTA-level GEMMs assuming + expectations about data layout of the global memory fragments, data types, + and internal tile sizes. + + Partial specializations for threadblock::Mma operations targeting sparse + TensorOp instructions. 
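+
+    Operand A is stored in compressed form along K (kSparse = 2), and an
+    additional metadata operand E records which elements of the original dense
+    A tile are the nonzero ones.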
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" + +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" + +#include "cutlass/gemm/warp/mma_simt_policy.h" +#include "cutlass/gemm/warp/mma_simt.h" +#include "cutlass/gemm/warp/default_mma_sparse_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" + +#include "cutlass/gemm/threadblock/default_mma_core.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" +#include "cutlass/gemm/threadblock/mma_sparse_multistage.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Template defininng default matrix multiply operators inferred from threadblock tile size, +/// global memory data layout, and target math instruction. +template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator = typename platform::conditional< + (platform::is_same::value) && + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false + /// Cache operation of operand A + , cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Global, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB = + cutlass::arch::CacheOperation::Global +> +struct DefaultSparseMmaCore; + +//////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: column-major +/// B: row-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultSparseMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::ColumnMajor; + using ElementB = ElementB_; + using LayoutB = layout::RowMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + static int const kSparse = 2; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Default Operator + using Operator = Operator_; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, int(128 / sizeof(ElementA))>; + + // Shared memory layout + using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, int(128 / sizeof(ElementB))>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 1, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared 
memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 0, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Cache operation of operand E + static cutlass::arch::CacheOperation::Kind const kCacheOpE = + cutlass::arch::CacheOperation::Global; + + static int const kInterleavedE = MmaTensorOp::kInterleaved; + static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits; + static int const kMaxID2 = MmaTensorOp::kMaxID2; + static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE; + + using ElementE = typename MmaTensorOp::ElementE; + using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved; + + // Shared memory layout. Interleaved layout is mapped to PitchLinear layout. + using SmemLayoutE = typename MmaTensorOp::LayoutE; + + /// ThreadMap of iterator E + static int const kElementsPerAccessE = + kAccessSizeInBits / sizeof_bits::value; + + /// E is tiny. Not all warps are needed. + static int const kThreadsE = + (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / + (kAccessSizeInBits / sizeof_bits::value) > + kThreads) + ? kThreads + : (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / + (kAccessSizeInBits / sizeof_bits::value)); + + using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreadsE, kElementsPerAccessE>; + + /// Shared memory iterator to E operand + using SmemIteratorE = transform::threadblock::RegularTileAccessIterator< + MatrixShape, + ElementE, SmemLayoutE, 0, IteratorThreadMapE>; + + /// Policy used to define MmaPipelined + using MmaPolicy = + SparseMmaPolicy, MatrixShape<0, 0>, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultSparseMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + 
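+  /// Ratio of the logical (dense) K extent to the stored K extent of operand A;
+  /// 2 corresponds to the 50% structured sparsity of Sparse Tensor Cores.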
static int const kSparse = 2; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kWarpThreadArrangementContiguousA = + Shape::kK / kSparse / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedA = + kWarpSize / kWarpThreadArrangementContiguousA; + + // crosswise cannot be larger than 1024 bit. + static int const kCrosswiseB = + (Shape::kK > (1024 / sizeof_bits::value)) + ? (1024 / sizeof_bits::value) + : Shape::kK; + + static int const kWarpThreadArrangementContiguousB = + kCrosswiseB / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedB = + kWarpSize / kWarpThreadArrangementContiguousB; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK / kSparse>; + + // Shared memory layout + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, kCrosswiseB>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 0, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 1, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Cache operation of operand E + static cutlass::arch::CacheOperation::Kind const kCacheOpE = + cutlass::arch::CacheOperation::Global; + + static int const kInterleavedE = MmaTensorOp::kInterleaved; + static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits; + static int const kMaxID2 = MmaTensorOp::kMaxID2; + static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE; + + using ElementE = typename MmaTensorOp::ElementE; + using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved; + + // Shared memory layout. Interleaved layout is mapped to PitchLinear layout. + using SmemLayoutE = typename MmaTensorOp::LayoutE; + + /// ThreadMap of iterator E + static int const kElementsPerAccessE = + kAccessSizeInBits / sizeof_bits::value; + + /// E is tiny. Not all warps are needed. 
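+  /// kThreadsE below evaluates to
+  ///   min(kThreads, Shape::kM * Shape::kK / kSparse / kElementsPerElementE / kElementsPerAccessE),
+  /// i.e. the number of accesses needed for one threadblock tile of E, capped at kThreads.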
+ static int const kThreadsE = + (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / + (kAccessSizeInBits / sizeof_bits::value) > + kThreads) + ? kThreads + : (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / + (kAccessSizeInBits / sizeof_bits::value)); + + using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreadsE, kElementsPerAccessE>; + + + /// Shared memory iterator to E operand + using SmemIteratorE = transform::threadblock::RegularTileAccessIterator< + MatrixShape, + ElementE, SmemLayoutE, 0, IteratorThreadMapE>; + + /// Policy used to define MmaPipelined + using MmaPolicy = + SparseMmaPolicy, MatrixShape<0, 0>, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: column-major +/// B: column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultSparseMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + + using LayoutA = layout::ColumnMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + static int const kSparse = 2; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + // crosswise cannot be larger than 1024 bit. + static int const kCrosswiseB = + (Shape::kK > (1024 / sizeof_bits::value)) + ? 
(1024 / sizeof_bits::value) + : Shape::kK; + + static int const kWarpThreadArrangementContiguousB = + kCrosswiseB / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedB = + kWarpSize / kWarpThreadArrangementContiguousB; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, int(128 / sizeof(ElementA))>; + + // Shared memory layout + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, kCrosswiseB>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 1, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 1, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Cache operation of operand E + static cutlass::arch::CacheOperation::Kind const kCacheOpE = + cutlass::arch::CacheOperation::Global; + + static int const kInterleavedE = MmaTensorOp::kInterleaved; + static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits; + static int const kMaxID2 = MmaTensorOp::kMaxID2; + static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE; + + using ElementE = typename MmaTensorOp::ElementE; + using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved; + + // Shared memory layout. Interleaved layout is mapped to PitchLinear layout. + using SmemLayoutE = typename MmaTensorOp::LayoutE; + + /// ThreadMap of iterator E + static int const kElementsPerAccessE = + kAccessSizeInBits / sizeof_bits::value; + + /// E is tiny. Not all warps are needed. + static int const kThreadsE = + (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / + (kAccessSizeInBits / sizeof_bits::value) > + kThreads) + ? 
kThreads + : (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / + (kAccessSizeInBits / sizeof_bits::value)); + + using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreadsE, kElementsPerAccessE>; + + /// Shared memory iterator to E operand + using SmemIteratorE = transform::threadblock::RegularTileAccessIterator< + MatrixShape, + ElementE, SmemLayoutE, 0, IteratorThreadMapE>; + + /// Policy used to define MmaPipelined + using MmaPolicy = + SparseMmaPolicy, MatrixShape<0, 0>, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: row-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultSparseMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::RowMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + static int const kSparse = 2; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kWarpThreadArrangementContiguousA = + Shape::kK / kSparse / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedA = + kWarpSize / kWarpThreadArrangementContiguousA; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK / kSparse>; + + // Shared memory layout + using SmemLayoutB = layout::RowMajorTensorOpMultiplicandCongruous< + sizeof_bits::value, int(128 / sizeof(ElementB))>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// 
Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 0, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape<8, 4>, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 0, + IteratorThreadMapB>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Cache operation of operand E + static cutlass::arch::CacheOperation::Kind const kCacheOpE = + cutlass::arch::CacheOperation::Global; + + static int const kInterleavedE = MmaTensorOp::kInterleaved; + static int const kMetaSizeInBits = MmaTensorOp::kMetaSizeInBits; + static int const kMaxID2 = MmaTensorOp::kMaxID2; + static int const kElementsPerElementE = MmaTensorOp::kElementsPerElementE; + + using ElementE = typename MmaTensorOp::ElementE; + using GmemLayoutE = cutlass::layout::ColumnMajorInterleaved; + + // Shared memory layout. Interleaved layout is mapped to PitchLinear layout. + using SmemLayoutE = typename MmaTensorOp::LayoutE; + + /// ThreadMap of iterator E + static int const kElementsPerAccessE = + kAccessSizeInBits / sizeof_bits::value; + + /// E is tiny. Not all warps are needed. + static int const kThreadsE = + (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / + (kAccessSizeInBits / sizeof_bits::value) > + kThreads) + ? kThreads + : (Shape::kM * Shape::kK / kSparse / kElementsPerElementE / + (kAccessSizeInBits / sizeof_bits::value)); + + using IteratorThreadMapE = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreadsE, kElementsPerAccessE>; + + /// Shared memory iterator to E operand + using SmemIteratorE = transform::threadblock::RegularTileAccessIterator< + MatrixShape, + ElementE, SmemLayoutE, 0, IteratorThreadMapE>; + + /// Policy used to define MmaPipelined + using MmaPolicy = + SparseMmaPolicy, MatrixShape<0, 0>, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/include/cutlass/gemm/threadblock/default_sparse_mma.h b/include/cutlass/gemm/threadblock/default_sparse_mma.h new file mode 100644 index 0000000000..3f6354771e --- /dev/null +++ b/include/cutlass/gemm/threadblock/default_sparse_mma.h @@ -0,0 +1,190 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h" +#if defined(CUTLASS_ARCH_WMMA_ENABLED) +#include "cutlass/gemm/threadblock/default_mma_core_wmma.h" +#endif //CUTLASS_ARCH_WMMA_ENABLED + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false + > +struct DefaultSparseMma; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp) +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator + > +struct DefaultSparseMma { + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + + static int const kSparse = MmaCore::kSparse; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define iterators over tiles from the E operand + using ElementE = typename MmaCore::ElementE; + using LayoutE = typename MmaCore::GmemLayoutE; + using ThreadMapE = typename MmaCore::IteratorThreadMapE; + using AccessTypeE = + cutlass::Array::value>; + using IteratorE = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementE, LayoutE, 1, ThreadMapE, AccessTypeE>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::SparseMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, ElementAccumulator, layout::RowMajor, + IteratorE, typename MmaCore::SmemIteratorE, MmaCore::kCacheOpE, + typename MmaCore::MmaPolicy, Stages>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + 
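+// A minimal sketch of composing the defaulted sparse threadblock-scoped MMA for
+// an SM80 f16 TN problem. The element types, layouts, alignments, tile shapes
+// and stage count below are illustrative choices (typical of SM80 f16
+// configurations), not values prescribed by this header.
+using ExampleSparseThreadblockMma =
+    cutlass::gemm::threadblock::DefaultSparseMma<
+        cutlass::half_t, cutlass::layout::RowMajor, 8,     // A (compressed along K)
+        cutlass::half_t, cutlass::layout::ColumnMajor, 8,  // B
+        float, cutlass::layout::RowMajor,                  // accumulator and output layout
+        cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
+        cutlass::gemm::GemmShape<128, 128, 64>,            // threadblock tile
+        cutlass::gemm::GemmShape<64, 64, 64>,              // warp tile
+        cutlass::gemm::GemmShape<16, 8, 32>,               // sparse mma instruction
+        3,                                                 // pipeline stages
+        cutlass::arch::OpMultiplyAdd>::ThreadblockMma;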
+//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/mma_sparse_base.h b/include/cutlass/gemm/threadblock/mma_sparse_base.h new file mode 100644 index 0000000000..c6bb3411fc --- /dev/null +++ b/include/cutlass/gemm/threadblock/mma_sparse_base.h @@ -0,0 +1,259 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
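+
+    In addition to the A and B operands, the base class below allocates shared
+    memory for the sparse metadata operand E and exposes a warp-scoped iterator
+    over it.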
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Policy object describing MmaTensorOp +template < + /// Warp-level GEMM operator (concept: gemm::warp::Mma) + typename Operator_, + /// Padding used for A operand in shared memory (concept: MatrixShape) + typename SmemPaddingA_, + /// Padding used for B operand in shared memory (concept: MatrixShape) + typename SmemPaddingB_, + /// Padding used for E operand in shared memory (concept: MatrixShape) + typename SmemPaddingE_, + /// Number of partitions of K dimension of GEMM + int PartitionsK = 1> +struct SparseMmaPolicy { + /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt) + using Operator = Operator_; + + /// Padding used for A operand in shared memory + using SmemPaddingA = SmemPaddingA_; + + /// Padding used for B operand in shared memory + using SmemPaddingB = SmemPaddingB_; + + /// Padding used for B operand in shared memory + using SmemPaddingE = SmemPaddingE_; + + /// Number of partitions of K dimension + static int const kPartitionsK = PartitionsK; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class SparseMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + static int const kSparse = Operator::kSparse; + + static int const kElementsPerElementE = Operator::kElementsPerElementE; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + /// Tensor reference to the E operand + using TensorRefE = TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = + MatrixShape; + + /// Shape of the E matrix operand in shared memory + using ShapeE = + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + /// Buffer for E operand + AlignedBuffer operand_E; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a layout object for the E matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutE LayoutE() { + return Operator::LayoutE::packed({ShapeE::kRow, ShapeE::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + + /// Returns a TensorRef to the E operand + CUTLASS_HOST_DEVICE + TensorRefE operand_E_ref() { + return TensorRefE{operand_E.data(), LayoutE()}; + } + }; + + protected: + + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + /// Iterator to load a warp-scoped tile of E operand from shared memory + typename Operator::IteratorE warp_tile_iterator_E_; + + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + SparseMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx), + warp_tile_iterator_E_(shared_storage.operand_E_ref(), lane_idx) { + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + 
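+// Back-of-the-envelope sketch of the shared memory staged per pipeline stage by
+// SparseMmaBase::SharedStorage, ignoring the SmemPadding terms of the policy.
+// This helper is illustrative only; in the real code kSparse and
+// kElementsPerElementE are taken from the warp-level operator.
+constexpr int sparse_smem_bytes_per_stage(
+    int tb_m, int tb_n, int tb_k,        // threadblock tile extents
+    int bits_a, int bits_b, int bits_e,  // element widths in bits
+    int sparse,                          // kSparse
+    int elements_per_element_e) {        // kElementsPerElementE
+  return tb_m * (tb_k / sparse) * bits_a / 8                             // compressed A tile
+       + tb_k * tb_n * bits_b / 8                                        // full-K B tile
+       + tb_m * (tb_k / sparse / elements_per_element_e) * bits_e / 8;   // metadata E tile
+}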
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/mma_sparse_multistage.h b/include/cutlass/gemm/threadblock/mma_sparse_multistage.h new file mode 100644 index 0000000000..a2ff84664a --- /dev/null +++ b/include/cutlass/gemm/threadblock/mma_sparse_multistage.h @@ -0,0 +1,667 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_sparse_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
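+/// In addition to the A and B operands handled by the dense MmaMultistage, this
+/// variant streams the sparse metadata operand E from global to shared memory
+/// with cp.async and passes it to the warp-level sparse tensor op.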
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Iterates over tiles of E operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorE_, + /// Iterates over tiles of E operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorE_, + /// Cache operation for operand E + cutlass::arch::CacheOperation::Kind CacheOpE, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class SparseMmaMultistage : + public SparseMmaBase { +public: + ///< Base class + using Base = SparseMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Iterates over tiles of E operand in global memory + using IteratorE = IteratorE_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorE = SmemIteratorE_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + static cutlass::arch::CacheOperation::Kind const kCacheOpE = CacheOpE; + + static int const kSparse = Policy::Operator::kSparse; + static int const kMetaSizeInBits = Policy::Operator::kMetaSizeInBits; + static int const kMaxID2 = Policy::Operator::kMaxID2; + static int const kElementsPerElementE = + Policy::Operator::kElementsPerElementE; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// ElementE + using ElementE = typename IteratorE::Element; + + /// LayoutE + using LayoutE = typename IteratorE::Layout; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. 
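+  /// Detail counts the cp.async operations required per stage and per warp-level
+  /// k-group for each operand, limits how many warps participate in loading the
+  /// small E operand, and decides whether the B fragments can be double buffered.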
+ struct Detail { + + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of async copies to load one stage of operand A + static int const TBLDGSTSIterationsA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of async copies to load one stage of operand B + static int const TBLDGSTSIterationsB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of async copies to load one stage of operand E + static int const TBLDGSTSIterationsE = + IteratorE::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of async copies to load one group of operand A + static int const kAccessesPerGroupA = + (TBLDGSTSIterationsA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of async copies to load one group of operand B + static int const kAccessesPerGroupB = + (TBLDGSTSIterationsB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of async copies to load one group of operand E + static int const kAccessesPerGroupE = + (TBLDGSTSIterationsE + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// E operand is tiny. For the most of time, not all the warps are needed + /// to load it from the global memory. + static int const kValidWarps = IteratorE::ThreadMap::kThreads / 32; + + /// B operand is twice as big as A which brings very high register pressure. + /// We have to sacrifice the double buffer when the warp tile size is big. + static int const kBBufferSize = + ((sizeof(typename Operator::ElementC) == 4) && + ((platform::is_same::value && + platform::is_same::value)) && + (Operator::Shape::kM >= 64 && Operator::Shape::kN >= 64)) + ? 1 + : 2; + }; + + private: + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + using WarpFragmentE = typename Operator::FragmentE; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of E operand to shared memory + SmemIteratorE smem_iterator_E_; + + /// Warp id + bool is_warp_valid_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + SparseMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_E_(shared_storage.operand_E_ref(), thread_idx) + { + is_warp_valid_ = warp_idx < Detail::kValidWarps; + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int 
warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + this->warp_tile_iterator_E_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, + IteratorE &iterator_E, int group_start_A = 0, + int group_start_B = 0, int group_start_E = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // async copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::TBLDGSTSIterationsA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // async copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::TBLDGSTSIterationsB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + + iterator_E.set_iteration_index(group_start_E); + this->smem_iterator_E_.set_iteration_index(group_start_E); + + // async copy for operand E + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupE; ++j) { + if (group_start_E + j < Detail::TBLDGSTSIterationsE) { + typename IteratorE::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_E_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorE::ThreadMap::kElementsPerAccess / 8; + + auto gmem_ptr = iterator_E.get(); + + cutlass::arch::cp_async( + dst_ptr, gmem_ptr, iterator_E.valid() && is_warp_valid_); + + ++iterator_E; + ++this->smem_iterator_E_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over E operand in global memory + IteratorE iterator_E, + ///< initial value of accumulator + FragmentC const 
&src_accum) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + + if (gemm_k_iterations == 0) { + iterator_A.clear_mask(); + iterator_B.clear_mask(); + iterator_E.clear_mask(); + } + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // async copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLDGSTSIterationsA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // async copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLDGSTSIterationsB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + iterator_E.set_iteration_index(0); + this->smem_iterator_E_.set_iteration_index(0); + + // async copy for operand E + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLDGSTSIterationsE; ++j) { + typename IteratorE::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_E_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorE::ThreadMap::kElementsPerAccess / 8; + if (is_warp_valid_) + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_E.get(), iterator_E.valid()); + + ++iterator_E; + + ++this->smem_iterator_E_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + iterator_E.add_tile_offset({0, 1}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + this->smem_iterator_E_.add_tile_offset({0, 1}); + + // LDGDEPBAR - completes a stage + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[Detail::kBBufferSize]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[Detail::kBBufferSize]; + WarpFragmentE warp_frag_E[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_E_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + this->warp_tile_iterator_E_.load(warp_frag_E[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + ++this->warp_tile_iterator_E_; + + if (gemm_k_iterations == 0) { + 
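+      // The prologue decremented gemm_k_iterations once per issued stage. If it
+      // reached zero, every K block is already in flight, so clear the
+      // global-memory predicates and let the remaining cp.async issues in the
+      // mainloop be predicated off. For example (illustrative values): with
+      // kStages = 3 and an initial gemm_k_iterations = 2, both K blocks are
+      // fetched by the prologue and the mainloop only drains shared memory.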
iterator_A.clear_mask();
+      iterator_B.clear_mask();
+      iterator_E.clear_mask();
+    }
+
+    int smem_write_stage_idx = Base::kStages - 1;
+    int smem_read_stage_idx = 0;
+
+    warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0],
+                       warp_loaded_frag_A[0], warp_loaded_frag_B[0]);
+
+    //
+    // Mainloop
+    //
+
+    CUTLASS_GEMM_LOOP
+    for (; gemm_k_iterations > (-Base::kStages + 1);) {
+      //
+      // Loop over GEMM K dimension
+      //
+
+      // Computes a warp-level GEMM on data held in shared memory
+      // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate
+      CUTLASS_PRAGMA_UNROLL
+      for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
+           ++warp_mma_k) {
+
+        // Load warp-level tiles from shared memory, wrapping to the k offset
+        // if this is the last group.
+
+        this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
+        this->warp_tile_iterator_E_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
+
+        this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]);
+        this->warp_tile_iterator_E_.load(warp_frag_E[(warp_mma_k + 1) % 2]);
+
+        ++this->warp_tile_iterator_A_;
+        ++this->warp_tile_iterator_E_;
+
+        if (Detail::kBBufferSize == 2) {
+          this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
+          this->warp_tile_iterator_B_.load(
+              warp_loaded_frag_B[(warp_mma_k + 1) % Detail::kBBufferSize]);
+          ++this->warp_tile_iterator_B_;
+        }
+
+        if (warp_mma_k > 0)
+          warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2],
+                             warp_transformed_frag_B[warp_mma_k % Detail::kBBufferSize],
+                             warp_loaded_frag_A[warp_mma_k % 2],
+                             warp_loaded_frag_B[warp_mma_k % Detail::kBBufferSize]);
+
+        warp_mma(
+          accum,
+          warp_transformed_frag_A[warp_mma_k % 2],
+          warp_transformed_frag_B[warp_mma_k % Detail::kBBufferSize], accum,
+          warp_frag_E[warp_mma_k % 2]
+        );
+
+        if (Detail::kBBufferSize == 1) {
+          this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
+          this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]);
+          ++this->warp_tile_iterator_B_;
+        }
+
+        // Issue global->shared copies for this stage
+        if (warp_mma_k < Base::kWarpGemmIterations - 1) {
+          int group_start_iteration_A, group_start_iteration_B, group_start_iteration_E;
+
+          group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
+          group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
+          group_start_iteration_E = warp_mma_k * Detail::kAccessesPerGroupE;
+
+          copy_tiles_and_advance(
+              iterator_A, iterator_B, iterator_E, group_start_iteration_A,
+              group_start_iteration_B, group_start_iteration_E);
+        }
+
+        if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
+          int group_start_iteration_A, group_start_iteration_B, group_start_iteration_E;
+          group_start_iteration_A =
+              (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
+          group_start_iteration_B =
+              (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
+          group_start_iteration_E =
+              (warp_mma_k + 1) * Detail::kAccessesPerGroupE;
+
+          copy_tiles_and_advance(
+              iterator_A, iterator_B, iterator_E, group_start_iteration_A,
+              group_start_iteration_B, group_start_iteration_E);
+
+          // Inserts a memory fence between stages of cp.async instructions.
+          cutlass::arch::cp_async_fence();
+
+          // Waits until kStages-2 stages have committed.
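+          // cp_async_fence() closes the group of cp.async operations issued for
+          // this stage; the wait below then blocks until at most kStages-2
+          // groups remain in flight, i.e. the oldest stage has fully landed in
+          // shared memory before warps begin reading it (the wait takes
+          // Base::kStages - 2 as its template argument).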
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + iterator_E.add_tile_offset({0, 1}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + this->smem_iterator_E_.add_tile_offset({0, 1}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_E_.add_tile_offset({0, -Base::kStages}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + this->warp_tile_iterator_E_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + if (gemm_k_iterations == 0) { + iterator_A.clear_mask(); + iterator_B.clear_mask(); + iterator_E.clear_mask(); + } + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/threadblock/threadblock_swizzle.h b/include/cutlass/gemm/threadblock/threadblock_swizzle.h index 03d71d3197..587de56a66 100644 --- a/include/cutlass/gemm/threadblock/threadblock_swizzle.h +++ b/include/cutlass/gemm/threadblock/threadblock_swizzle.h @@ -123,16 +123,22 @@ struct GemmIdentityThreadblockSwizzle { /// Computes CUDA grid dimensions given a size in units of logical tiles CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const { + if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile)) + return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k()); + return dim3(tiled_shape.m() * kTile, (tiled_shape.n() + kTile - 1) / kTile, tiled_shape.k()); } /// Obtains the threadblock offset (in units of threadblock-scoped tiles) CUTLASS_DEVICE - GemmCoord get_tile_offset() const { + GemmCoord get_tile_offset(GemmCoord tiled_shape) const { int block_idx_x = RematerializeBlockIdxX(); int block_idx_y = RematerializeBlockIdxY(); + if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile)) + return GemmCoord{block_idx_x, block_idx_y, RematerializeBlockIdxZ()}; + return GemmCoord{ (block_idx_x / kTile), (block_idx_y * kTile) + (block_idx_x % kTile), @@ -170,7 +176,7 @@ struct GemmHorizontalThreadblockSwizzle { /// Obtains the threadblock offset (in units of threadblock-scoped tiles) CUTLASS_DEVICE - GemmCoord get_tile_offset() const { + GemmCoord get_tile_offset(GemmCoord tiled_shape) const { return 
GemmCoord{ RematerializeBlockIdxY(), RematerializeBlockIdxX(), @@ -205,7 +211,7 @@ struct GemmBatchedIdentityThreadblockSwizzle { /// Obtains the threadblock offset (in units of threadblock-scoped tiles) CUTLASS_DEVICE - GemmCoord get_tile_offset() const { + GemmCoord get_tile_offset(GemmCoord tiled_shape) const { return GemmCoord{ RematerializeBlockIdxX(), RematerializeBlockIdxY(), @@ -244,17 +250,23 @@ struct GemmSplitKIdentityThreadblockSwizzle { /// Computes CUDA grid dimensions given a size in units of logical tiles CUTLASS_HOST_DEVICE dim3 get_grid_shape(GemmCoord tiled_shape) const { + if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile)) + return dim3(tiled_shape.m(), tiled_shape.n(), tiled_shape.k()); + return dim3(tiled_shape.m() * kTile, (tiled_shape.n() + kTile - 1) / kTile, tiled_shape.k()); } /// Obtains the threadblock offset (in units of threadblock-scoped tiles) CUTLASS_DEVICE - GemmCoord get_tile_offset() const { + GemmCoord get_tile_offset(GemmCoord tiled_shape) const { int block_idx_x = RematerializeBlockIdxX(); int block_idx_y = RematerializeBlockIdxY(); + if ((tiled_shape.m() < kTile) || (tiled_shape.n() < kTile)) + return GemmCoord{block_idx_x, block_idx_y, RematerializeBlockIdxZ()}; + return GemmCoord{ (block_idx_x / kTile), (block_idx_y * kTile) + (block_idx_x % kTile), @@ -290,7 +302,7 @@ struct GemmSplitKHorizontalThreadblockSwizzle { /// Obtains the threadblock offset (in units of threadblock-scoped tiles) CUTLASS_DEVICE - GemmCoord get_tile_offset() const { + GemmCoord get_tile_offset(GemmCoord tiled_shape) const { return GemmCoord{ RematerializeBlockIdxY(), RematerializeBlockIdxX(), diff --git a/include/cutlass/gemm/warp/default_mma_sparse_tensor_op.h b/include/cutlass/gemm/warp/default_mma_sparse_tensor_op.h new file mode 100644 index 0000000000..637e39009e --- /dev/null +++ b/include/cutlass/gemm/warp/default_mma_sparse_tensor_op.h @@ -0,0 +1,159 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/mma_sparse_tensor_op.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Operator describing the tensor operation + typename Operator_ = arch::OpMultiplyAdd, + /// Number of partitions along K dimension + int PartitionsK = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false +> +struct DefaultSparseMmaTensorOp; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial Specialization - inputs and output types are float - uses TF32 internally +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of target matrix multiply instruction (concept: GemmShape) + typename InstructionShape_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Number of partitions along K dimension + int PartitionsK, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor> +struct DefaultSparseMmaTensorOp< + WarpShape_, + InstructionShape_, + float, LayoutA, + float, LayoutB, + float, LayoutC, + arch::OpMultiplyAdd, PartitionsK, AccumulatorsInRowMajor> { + + // Uses TF32 internally + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::SparseMma< + InstructionShape_, + 32, + tfloat32_t, cutlass::layout::RowMajor, + tfloat32_t, cutlass::layout::ColumnMajor, + float, cutlass::layout::RowMajor, + arch::OpMultiplyAdd + >, + cutlass::MatrixShape<1, 1> >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::SparseMmaTensorOp< + WarpShape_, float, LayoutA, float, LayoutB, float, LayoutC, + Policy, PartitionsK, AccumulatorsInRowMajor>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for m-by-n-by-kgroup +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Operator describing the tensor operation + typename Operator_, + /// Number of partitions along K dimension + int PartitionsK, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor> +struct DefaultSparseMmaTensorOp { + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::SparseMma, + cutlass::MatrixShape<1, 1> >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::SparseMmaTensorOp< + WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, + Policy, PartitionsK, AccumulatorsInRowMajor>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_simt.h b/include/cutlass/gemm/warp/mma_simt.h index 1bf23c7432..c90624cee7 100644 --- a/include/cutlass/gemm/warp/mma_simt.h +++ b/include/cutlass/gemm/warp/mma_simt.h @@ -147,6 +147,9 @@ class MmaSimt { dp4a_type >; + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename ThreadMma::ArchMmaOperator; + /// Shape of the underlying instruction using InstructionShape = GemmShape<1,1,use_dp4a ? 4 : 1>; diff --git a/include/cutlass/gemm/warp/mma_sparse_tensor_op.h b/include/cutlass/gemm/warp/mma_sparse_tensor_op.h new file mode 100644 index 0000000000..8b7312baa0 --- /dev/null +++ b/include/cutlass/gemm/warp/mma_sparse_tensor_op.h @@ -0,0 +1,335 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate + operations targeting sparse Tensor Cores. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" +#include "cutlass/gemm/warp/mma_tensor_op.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool +> +class SparseMmaTensorOp { +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Architecture tag from underlying instruction + using ArchTag = typename Policy::Operator::ArchTag; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + /// Sparsity in Operand A + static int const kSparse = Policy::Operator::kSparse; + + /// Meta data size in bits + static int const kMetaSizeInBits = Policy::Operator::kMetaSizeInBits; + + /// Max ID2 + static int const kMaxID2 = Policy::Operator::kMaxID2; + + /// Data type of meta E that is moved at the same time + using ElementE = + typename cutlass::platform::conditional::type; + + /// Number of ElementA that is associated with one ElementE + static int const kElementsPerElementE = + 128 / cutlass::sizeof_bits::value; + + /// Meta data is essentially interleaved but mapped to ColumnMajor internally + static int const kInterleaved = 2; + + /// Layout of meta E + using LayoutE = cutlass::layout::ColumnMajor; + + public: + + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, Operand::kA, ElementA, + LayoutA, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = + Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, Operand::kB, ElementB, LayoutB, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using TransformedFragmentB = + Array; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator< + MatrixShape, ElementC, LayoutC, + typename Policy::Operator::Shape, typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Iterates over the E operand in memory + using IteratorE = SparseMmaTensorOpMetaTileIterator< + MatrixShape, + ElementE, LayoutE, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for E tile + using FragmentE = typename IteratorE::Fragment; + +private: + + static_assert( + !(Shape::kM % Policy::Operator::Shape::kM) && + !(Shape::kN % Policy::Operator::Shape::kN), + "Shape of warp-level Mma must be 
divisible by operator shape."); + + /// Number of mma operations performed + using MmaIterations = MatrixShape< + Shape::kM / Policy::Operator::Shape::kM, + Shape::kN / Policy::Operator::Shape::kN + >; + +public: + + /// Underlying matrix multiply operator (concept: arch::Mma) + typename Policy::Operator mma; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + SparseMmaTensorOp() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + TransformedFragmentA const &A, + TransformedFragmentB const &B, + FragmentC const &C, + FragmentE const &E + ) const { + + using MmaOperandA = typename Policy::Operator::FragmentA; + using MmaOperandB = typename Policy::Operator::FragmentB; + using MmaOperandC = typename Policy::Operator::FragmentC; + using MmaOperandE = typename Policy::Operator::FragmentE; + + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + D = C; + + MmaOperandA const *ptr_A = reinterpret_cast(&A); + MmaOperandB const *ptr_B = reinterpret_cast(&B); + MmaOperandC *ptr_D = reinterpret_cast(&D); + MmaOperandE const *ptr_E = reinterpret_cast(&E); + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + int id2 = m % kMaxID2; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( + ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_E[(m / kMaxID2)], + id2); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_E[(m / kMaxID2)], + id2); + } + } + } + #else + assert(0); + #endif + } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const { + + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // + // Define conversions from source type to instruction type + // + FloatRoundStyle const kRoundA = + PreferredRoundingMode::kRound; + FloatRoundStyle const kRoundB = + PreferredRoundingMode::kRound; + detail::ConvertAndPack + convert_A; + NumericArrayConverter + convert_B; + Array const *ptr_A = + reinterpret_cast const *>(&A); + Array * + ptr_dst_A = reinterpret_cast *>(&dst_A); + + dst_B = convert_B(B); + + ptr_dst_A[0] = convert_A(ptr_A[0]); + ptr_dst_A[1] = convert_A(ptr_A[1]); + #else + assert(0); + #endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_tensor_op.h b/include/cutlass/gemm/warp/mma_tensor_op.h index 3eff7b9054..1a10c7e4fe 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_tensor_op.h @@ -184,14 +184,17 @@ class MmaTensorOp { /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) using Policy = Policy_; + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + /// Architecture tag from underlying instruction - using ArchTag = typename Policy::Operator::ArchTag; + using ArchTag = typename 
ArchMmaOperator::ArchTag; /// Indicates class of matrix operator using OperatorClass = arch::OpClassTensorOp; /// Shape of underlying instruction - using InstructionShape = typename Policy::Operator::Shape; + using InstructionShape = typename ArchMmaOperator::Shape; /// Complex transform on A operand static ComplexTransform const kTransformA = ComplexTransform::kNone; @@ -210,7 +213,7 @@ class MmaTensorOp { /// Iterates over the A operand in memory using IteratorA = MmaTensorOpMultiplicandTileIterator< MatrixShape, Operand::kA, ElementA, LayoutA, - MatrixShape, + MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; /// Storage for A tile @@ -218,12 +221,12 @@ class MmaTensorOp { /// Storage for transformed A tile using TransformedFragmentA = - Array; + Array; /// Iterates over the B operand in memory using IteratorB = MmaTensorOpMultiplicandTileIterator< MatrixShape, Operand::kB, ElementB, LayoutB, - MatrixShape, + MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; /// Storage for B tile @@ -231,33 +234,28 @@ class MmaTensorOp { /// Storage for transformed B tile using TransformedFragmentB = - Array; + Array; /// Iterates over the C operand in memory using IteratorC = MmaTensorOpAccumulatorTileIterator< MatrixShape, ElementC, LayoutC, - typename Policy::Operator::Shape, typename Policy::OpDelta>; + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; /// Storage for C tile using FragmentC = typename IteratorC::Fragment; private: - static_assert( - !(Shape::kM % Policy::Operator::Shape::kM) && - !(Shape::kN % Policy::Operator::Shape::kN), - "Shape of warp-level Mma must be divisible by operator shape."); - /// Number of mma operations performed using MmaIterations = MatrixShape< - Shape::kM / Policy::Operator::Shape::kM, - Shape::kN / Policy::Operator::Shape::kN + (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN >; public: /// Underlying matrix multiply operator (concept: arch::Mma) - typename Policy::Operator mma; + ArchMmaOperator mma; public: @@ -278,9 +276,9 @@ class MmaTensorOp { FragmentC const &C ) const { - using MmaOperandA = typename Policy::Operator::FragmentA; - using MmaOperandB = typename Policy::Operator::FragmentB; - using MmaOperandC = typename Policy::Operator::FragmentC; + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; D = C; @@ -351,22 +349,22 @@ class MmaTensorOp { // Define conversions from source type to instruction type // FloatRoundStyle const kRoundA = - PreferredRoundingMode::kRound; FloatRoundStyle const kRoundB = - PreferredRoundingMode::kRound; #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - detail::ConvertAndPack convert_A; - NumericArrayConverter convert_B; Array const *ptr_B = reinterpret_cast const *>(&B); - Array * - ptr_dst_B = reinterpret_cast * + ptr_dst_B = reinterpret_cast *>(&dst_B); dst_A = convert_A(A); @@ -375,16 +373,16 @@ class MmaTensorOp { ptr_dst_B[1] = convert_B(ptr_B[1]); #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - detail::ConvertAndPack convert_A; - NumericArrayConverter convert_B; Array const *ptr_A = reinterpret_cast const *>(&A); - Array * - ptr_dst_A = reinterpret_cast * + ptr_dst_A = reinterpret_cast *>(&dst_A); dst_B = convert_B(B); diff --git a/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h 
b/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h index 85f5009d8c..5b5b5345a0 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h @@ -291,7 +291,9 @@ class MmaTensorOpFragmentIterator; + /// Number of Accesses in a warp + using AccessIterations = MatrixShape; + /// Number of K iterations static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; - static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn - * (AccumulatorShape::kRow / Shape::kRow); - static int const kResidualIndex = kResidualColumn / Shape::kColumn - * (AccumulatorShape::kRow / Shape::kRow); + static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn; + static int const kResidualIndex = kResidualColumn / Shape::kColumn; public: @@ -338,8 +351,8 @@ class MmaTensorOpFragmentIterator; - using FragmentAccessType = Array; + using AccessType = Array; + using FragmentAccessType = Array; private: // @@ -386,6 +399,11 @@ class MmaTensorOpFragmentIterator(&frag); // NumericArrayConverter fragmentConverter; - int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow; - int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow - * MmaIterations::kColumn; + int index = index_ * AccessIterations::kCount; CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; m++) { - for (int n = 0; n < MmaIterations::kColumn; n++) { - int accumulator_access_offset = - (m + index_m) * AccumulatorIterations::kColumn + n + index_n; - - frag_ptr[m * MmaIterations::kColumn + n].clear(); + for (int i = 0; i < AccessIterations::kCount; i++) { +// int index_m = (index % AccessIterations::kCount) / (AccessIterations::kColumn * kIterationsPerInstruction) +// * kIterationsPerInstruction + index % kIterationsPerInstruction; +// +// int index_n = (index / AccessIterations::kCount) * MmaIterations::kColumn + +// (index % (AccessIterations::kColumn * kIterationsPerInstruction)) +// / kIterationsPerInstruction * AccessIterations::kColumn; +// +// int accumulator_access_offset = index_m / kIterationsPerInstruction * AccessIterations::kCount * kIterationsPerInstruction +// + index_m % kIterationsPerInstruction + index_n * kIterationsPerInstruction; + + int accumulator_access_offset = index / AccessIterations::kCount * (MmaIterations::kColumn * kIterationsPerInstruction) + + (index % AccessIterations::kCount) / (AccessIterations::kColumn * kIterationsPerInstruction) * + AccumulatorIterations::kColumn * kIterationsPerInstruction + + (index % (AccessIterations::kColumn * kIterationsPerInstruction)) / kIterationsPerInstruction * + (kIterationsPerInstruction * kIterationsPerAccess) + + (index % kIterationsPerInstruction); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kIterationsPerAccess; j++) { + + frag_ptr[i*kIterationsPerAccess + j].clear(); if(!(is_residual_tile_ && index_ >= kResidualIndex)) -// frag_ptr[m * MmaIterations::kColumn + n] = fragmentConverter(accumulators_[accumulator_access_offset]); - frag_ptr[m * MmaIterations::kColumn + n] = output_op(accumulators_[accumulator_access_offset], src_fragment); + // frag_ptr[m * MmaIterations::kColumn + n] = fragmentConverter(accumulators_[accumulator_access_offset]); + frag_ptr[i*kIterationsPerAccess + j] = output_op(accumulators_[accumulator_access_offset + j * kAccessStride], src_fragment); } + index++; 
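+      // Each pass of this loop maps one logical access index onto the
+      // accumulator tile: accumulator_access_offset decomposes the index into
+      // instruction-, column- and row-level components, and the inner j-loop
+      // gathers kIterationsPerAccess fragments spaced kAccessStride apart,
+      // leaving fragments past the residual-tile boundary cleared.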
} } diff --git a/include/cutlass/gemm/warp/mma_tensor_op_sm70.h b/include/cutlass/gemm/warp/mma_tensor_op_sm70.h index 063c77f9cc..cc1a909532 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_sm70.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_sm70.h @@ -106,8 +106,11 @@ class MmaVoltaTensorOp { /// Architecture tag using ArchTag = arch::Sm70; + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + /// Underlying instruction shape - using InstructionShape = typename Policy::Operator::Shape; + using InstructionShape = typename ArchMmaOperator::Shape; /// Complex transform on A operand static ComplexTransform const kTransformA = ComplexTransform::kNone; @@ -133,8 +136,8 @@ class MmaVoltaTensorOp { ElementA, LayoutA, MatrixShape< - Policy::Operator::Shape::kM, - Policy::Operator::Shape::kK + ArchMmaOperator::Shape::kM, + ArchMmaOperator::Shape::kK >, Policy::OpDelta::kRow, kThreadCount @@ -150,8 +153,8 @@ class MmaVoltaTensorOp { ElementB, LayoutB, MatrixShape< - Policy::Operator::Shape::kK, - Policy::Operator::Shape::kN + ArchMmaOperator::Shape::kK, + ArchMmaOperator::Shape::kN >, Policy::OpDelta::kRow, kThreadCount @@ -165,7 +168,7 @@ class MmaVoltaTensorOp { MatrixShape, ElementC, LayoutC, - typename Policy::Operator::Shape, + typename ArchMmaOperator::Shape, typename Policy::OpDelta >; @@ -175,14 +178,14 @@ class MmaVoltaTensorOp { private: static_assert( - !(Shape::kM % Policy::Operator::Shape::kM) && - !(Shape::kN % Policy::Operator::Shape::kN), + !(Shape::kM % ArchMmaOperator::Shape::kM) && + !(Shape::kN % ArchMmaOperator::Shape::kN), "Shape of warp-level Mma must be divisible by operator shape."); /// Number of mma operations performed using MmaIterations = MatrixShape< - InterleavedTileShape::kM / Policy::Operator::Shape::kM, - InterleavedTileShape::kN / Policy::Operator::Shape::kN + InterleavedTileShape::kM / ArchMmaOperator::Shape::kM, + InterleavedTileShape::kN / ArchMmaOperator::Shape::kN >; using TileIterations = MatrixShape< Shape::kM / InterleavedTileShape::kM, @@ -195,7 +198,7 @@ class MmaVoltaTensorOp { public: /// Underlying matrix multiply operator (concept: arch::Mma) - typename Policy::Operator mma; + ArchMmaOperator mma; public: @@ -215,9 +218,9 @@ class MmaVoltaTensorOp { FragmentB const &B, FragmentC const &C) { - using MmaOperandA = typename Policy::Operator::FragmentA; - using MmaOperandB = typename Policy::Operator::FragmentB; - using MmaOperandC = typename Policy::Operator::FragmentC; + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; D = C; diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h index 1a8fa4f915..1fe04e92af 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h @@ -241,6 +241,9 @@ class MmaTensorOpMultiplicandTileIterator< int access_strided_idx = -1; if (Policy::LdsmShape::kContiguous == 4) { + // Matrix multiply 1688 A/B + // Q0 Q1 Q2 Q3 (Q stands for 1 8x128bit block). + // Four blocks are next to each other in the contiguous dimension. 
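+      // A single ldmatrix.x4 therefore fetches Q0..Q3 at once; the XOR terms
+      // below select which 8x128-bit block each lane reads so the crosswise
+      // swizzle of the shared-memory layout is undone while avoiding bank
+      // conflicts.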
partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ i); access_contiguous_idx = (quad_pair ^ lane_in_quad); access_strided_idx = lane_in_quad_pair; @@ -262,7 +265,17 @@ class MmaTensorOpMultiplicandTileIterator< partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ (i >> 1)); access_contiguous_idx = ((quad_quad + ((i & 1) << 1)) ^ lane_in_quad); access_strided_idx = lane_in_quad_quad; + } else if (Policy::LdsmShape::kContiguous == 1) { + // Matrix multiply 16832.SP B + // Q0 + // Q1 + // Q2 + // Q3 + partition_contiguous_idx = ((lane_in_quad_pair >> 2) ^ (i >> 2)); + access_contiguous_idx = ((i & 3) ^ lane_in_quad); + access_strided_idx = lane_id; } + int access_contiguous = partition_contiguous_idx * Layout::PartitionShape::kContiguous + access_contiguous_idx; @@ -531,24 +544,24 @@ class MmaTensorOpMultiplicandTileIterator< !(Shape::kContiguous % InstructionShape::kContiguous), "Shape of warp-level Mma must be divisible by operator shape."); - // Determine number of elements along outer dimension per individual LDS.32 - // op. Every one warp of LDS.32 loads 8x4 elements + // Determine number of elements along outer dimension per individual 32bit + // shared memory load op. Every one warp of 32bit shared memory load loads + // 8x4 elements static int const kLdsOpInner = Layout::TileShape::kStrided; static int const kLdsOpOuter = kThreads / kLdsOpInner; static_assert(!(Shape::kContiguous % kLdsOpOuter), - "Shape of warp-level mma must be divisible by LDS.32's " + "Shape of warp-level mma must be divisible by 32bit " "fundamental tile size."); static_assert(!(Shape::kStrided % kLdsOpInner), - "Shape of warp-level mma must be divisible by LDS.32's " + "Shape of warp-level mma must be divisible by 32bit " "fundamental tile size."); - /// Number of LDS.32 instructions needed by one MMA instruction - /// 1684 A 2x1 - /// 1684 B 1x1 - /// 1688 A 2x2 - /// 1688 B 1x2 + /// Number of 32 bit shared memory load instructions needed by one MMA instruction + /// 1688 A 2x2 + /// 1688 B 1x2 + /// 16816 B 1x4 static int const LdsShapeContiguous = InstructionShape::kContiguous / kLdsOpOuter; static int const LdsShapeStrided = InstructionShape::kStrided / kLdsOpInner; @@ -639,6 +652,8 @@ class MmaTensorOpMultiplicandTileIterator< if (Shape::kContiguous == Layout::TileShape::kContiguous * Layout::kElementsPerAccess / 2) { if (tile_offset.contiguous() % 2) { + // Matrix multiply 1688 pointer_[0] <=> pointer_[4] pointer_[1] <=> pointer_[5] + // pointer_[2] <=> pointer_[6] pointer_[3] <=> pointer_[7] CUTLASS_PRAGMA_UNROLL for (int i = 0; i < kPointerCount / 2; ++i) { AccessType const *tmp_pointer = pointer_[i]; @@ -1535,6 +1550,14 @@ class MmaTensorOpMultiplicandTileIterator< access_strided_idx = (lane_in_quad_pair + (lane_id >> 4 << 3)) / Layout::kFactor; } + else if (Policy::LdsmShape::kContiguous == Policy::LdsmShape::kCount) { + // Matrix multiply 16832.SP B + // Q0 Q1 Q2 Q3 + partition_contiguous_idx = (lane_id % Layout::kFactor); + access_contiguous_idx = + (quad_pair ^ (lane_in_quad_pair / Layout::kFactor)); + access_strided_idx = lane_in_quad_pair / Layout::kFactor; + } } else if (Layout::kFactor == 1) { // Super Matrix multiply kBlock = 64 if (Policy::LdsmShape::kStrided == Policy::LdsmShape::kCount) { @@ -1565,6 +1588,13 @@ class MmaTensorOpMultiplicandTileIterator< access_contiguous_idx = ((quad_pair & 1) ^ lane_in_quad); access_strided_idx = lane_in_quad_pair + (lane_id >> 4 << 3); } + else if (Policy::LdsmShape::kContiguous == Policy::LdsmShape::kCount) { + // Matrix multiply 16832.SP B + // Q0 
Q1 Q2 Q3 + partition_contiguous_idx = (lane_in_quad_pair >> 2); + access_contiguous_idx = (quad_pair ^ lane_in_quad); + access_strided_idx = lane_in_quad_pair; + } } int access_contiguous = @@ -2369,17 +2399,18 @@ class MmaTensorOpAccumulatorTileIterator< /// Internal structure of iterator - made public to enable introspection struct Policy { - static_assert( + static bool const kDivisible = !(Shape::kRow % InstructionShape::kM) && - !(Shape::kColumn % InstructionShape::kN), - "Shape of warp-level Mma must be divisible by operator shape."); + !(Shape::kColumn % InstructionShape::kN); static_assert(platform::is_same::value, "Layouts must be defined for logical MatrixCoord coordinate space."); /// Number of mma operations performed - using MmaIterations = MatrixShape; + using MmaIterations = MatrixShape< + (Shape::kRow + InstructionShape::kM - 1) / InstructionShape::kM, + (Shape::kColumn + InstructionShape::kN - 1) / InstructionShape::kN + >; }; private: @@ -2398,7 +2429,9 @@ class MmaTensorOpAccumulatorTileIterator< // /// Fragment object holding a thread's part of a tile - using Fragment = Array; + using Fragment = Array< + Element, + Policy::MmaIterations::kCount * InstructionShape::kMN / kThreads>; private: @@ -2667,17 +2700,18 @@ class MmaTensorOpAccumulatorTileIterator::value, "Layouts must be defined for logical MatrixCoord coordinate space."); /// Number of mma operations performed - using MmaIterations = MatrixShape; + using MmaIterations = MatrixShape< + (Shape::kRow + InstructionShape::kM - 1) / InstructionShape::kM, + (Shape::kColumn + InstructionShape::kN - 1) / InstructionShape::kN + >; }; private: @@ -2696,7 +2730,8 @@ class MmaTensorOpAccumulatorTileIterator; + using Fragment = Array; private: diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h index e43373b64f..e286ed1162 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h @@ -1570,6 +1570,839 @@ class MmaTensorOpMultiplicandTileIterator< } }; +//////////////////////////////////////////////////////////////////////////////// + + +/// Tile iterator specialized for canonical matrix layouts +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + Operand Operand_, + /// Data type of A elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Shape of one matrix production operation (concept: MatrixShape) + typename InstructionShape_, + /// Delta between *MMA operations (in units of *MMA operations, concept: + /// MatrixShape) + int OpDelta_, + /// Number of threads participating in one matrix operation + int Threads = 32, + /// Number of partitions along K dimension + int PartitionsK_ = 1> +class MmaTensorOpMultiplicandTileIteratorCanonical { + public: + + /// Shape of tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand_; + + /// Basic check + static_assert(kOperand == Operand::kA || kOperand== Operand::kB, + "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = Layout_; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + 
static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Number of elements accessed per Shared Memory load + static int const kElementsPerAccess = + (sizeof_bits::value >= 32 ? 1 : 32 / sizeof_bits::value); + +private: + + static int const kWarpShapeOuter = + (kOperand == Operand::kA ? Shape::kRow : Shape::kColumn); + + static int const kWarpShapeInner = + (kOperand == Operand::kA ? Shape::kColumn : Shape::kRow); + + + /// Rounded up instruction counts + using InstructionCount = MatrixShape< + Shape::kRow / InstructionShape::kRow, + Shape::kColumn / InstructionShape::kColumn + >; + + /// Rounded up tile dimensions + using WarpShapeDivisible = MatrixShape< + InstructionCount::kRow * InstructionShape::kRow, + InstructionCount::kColumn * InstructionShape::kColumn + >; + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = Array< + Element, + WarpShapeDivisible::kRow * WarpShapeDivisible::kColumn / kThreads + >; + + /// Memory access type + using AccessType = AlignedArray; + +private: + + /// Underlying tensor reference + TensorRef ref_; + + /// Extent of tensor + MatrixCoord extent_; + + /// Origin + MatrixCoord origin_; + + /// Used to conditionally enable extents checking + bool divisible_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIteratorCanonical(): divisible_(true) { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIteratorCanonical( + TensorRef const &ref, + int lane_id + ): ref_(ref), extent_(Shape::kRow, Shape::kColumn), divisible_(true) { + + if (kOperand == Operand::kA) { + origin_ = MatrixCoord(lane_id / 4, (lane_id % 4) * kElementsPerAccess); + } + else { + origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess, lane_id / 4); + } + + ref_.add_coord_offset(origin_); + } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIteratorCanonical( + TensorRef const &ref, + TensorCoord extent, + int lane_id + ): ref_(ref), extent_(extent), divisible_(false) { + + if (kOperand == Operand::kA) { + origin_ = MatrixCoord(lane_id / 4, (lane_id % 4) * kElementsPerAccess); + } + else { + origin_ = MatrixCoord((lane_id % 4) * kElementsPerAccess, lane_id / 4); + } + + ref_.add_coord_offset(origin_); + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIteratorCanonical &add_pointer_offset(LongIndex offset) { + + ref_.add_pointer_offset(offset); + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIteratorCanonical &add_tile_offset(TensorCoord const &tile_offset) { + + TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); + origin_ += coord_offset; + + ref_.add_coord_offset(coord_offset); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIteratorCanonical & operator++() { + + if (kOperand == Operand::kA) { 
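+    // Advance along the GEMM K dimension: operand A moves one whole tile to
+    // the right (columns), while the else-branch below moves operand B one
+    // whole tile down (rows).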
+ add_tile_offset({0, 1}); + } + else { + add_tile_offset({1, 0}); + } + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIteratorCanonical & operator--() { + + if (kOperand == Operand::kA) { + add_tile_offset({0, -1}); + } + else { + add_tile_offset({-1, 0}); + } + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIteratorCanonical & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(tile_offset); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIteratorCanonical & operator-=(TensorCoord const &tile_offset) { + add_tile_offset(-tile_offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + load_with_pointer_offset(frag, 0); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + + int const kWarpShapeDivisibleInner = + (kOperand == Operand::kA ? WarpShapeDivisible::kColumn : WarpShapeDivisible::kRow); + + // Take advantage of Tensor Op's 8 x 4T access pattern + int const kAccessesInner = (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; + + AccessType *access_ptr = reinterpret_cast(&frag); + + if (kOperand == Operand::kA) { + int const kTilesPerInstruction = InstructionShape::kRow / 8; + + CUTLASS_PRAGMA_UNROLL + for (int inst_m_idx = 0; inst_m_idx < InstructionCount::kRow; ++inst_m_idx) { + + CUTLASS_PRAGMA_UNROLL + for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { + + CUTLASS_PRAGMA_UNROLL + for (int access_m_idx = 0; access_m_idx < kTilesPerInstruction; ++access_m_idx) { + int access_idx = + access_m_idx + kTilesPerInstruction * (inner_idx + kAccessesInner * inst_m_idx); + + MatrixCoord offset( + access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, + inner_idx * 4 * kElementsPerAccess); + + MatrixCoord access_coord = origin_ + offset; + + if (divisible_ || + (access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) { + + access_ptr[access_idx] = *reinterpret_cast( + ref_.data() + ref_.offset(offset)); + } + else { + AccessType zero; + zero.clear(); + access_ptr[access_idx] = zero; + } + } + } + } + } + else { + CUTLASS_PRAGMA_UNROLL + for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; ++inst_n_idx) { + + CUTLASS_PRAGMA_UNROLL + for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { + int access_idx = inner_idx + kAccessesInner * inst_n_idx; + + MatrixCoord offset( + inner_idx * 4 * kElementsPerAccess, + inst_n_idx * 8); + + MatrixCoord access_coord = origin_ + offset; + + if (divisible_ || + (access_coord.row() < extent_.row() && access_coord.column() < extent_.column())) { + + access_ptr[access_idx] = *reinterpret_cast( + ref_.data() + ref_.offset(offset)); + } + else { + AccessType zero; + zero.clear(); + access_ptr[access_idx] = zero; + } + } + } + } + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index byte_offset) const { + + 
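+    // Convert the byte offset to an element offset: 8 bits per byte divided by
+    // the element width. For example, byte_offset = 32 with a 16-bit element
+    // type (e.g. half_t) yields a pointer offset of 16 elements.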
load_with_pointer_offset(frag, byte_offset * 8 / sizeof_bits::value); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + + TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); + + load_with_pointer_offset(frag, ref_.offset(coord_offset)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + + TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); + + load_with_pointer_offset(frag, ref_.offset(coord_offset) + pointer_offset); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + + TensorCoord coord_offset(tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); + + load_with_pointer_offset(frag, ref_.offset(coord_offset) + byte_offset * 8 / sizeof_bits::value); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. 
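+  /// For this canonical (row-major / column-major) iterator the addressing
+  /// does not depend on the k-group, so the method below is intentionally a
+  /// no-op.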
+ CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + // no operation + } +}; + +/// Wrapper for ColumnMajor +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Identifies A or B multiplicand + Operand Operand_, + /// Data type of elements + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions) + int OpDelta_, + /// Number of partitions along K dimension + int PartitionsK_> +class MmaTensorOpMultiplicandTileIterator< + Shape_, Operand_, Element_, + cutlass::layout::ColumnMajor, + InstructionShape_, OpDelta_, 32, PartitionsK_> { + public: + + /// Shape of tile to load (concept: PitchLinearShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand_; + + static_assert(kOperand == Operand::kA || kOperand== Operand::kB, + "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::ColumnMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Underlying tile iterator implementation + using Base = MmaTensorOpMultiplicandTileIteratorCanonical< + Shape, kOperand, Element, + layout::ColumnMajor, + InstructionShape, + kOpDelta, kThreads, PartitionsK_>; + + public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = typename Base::Fragment; + +private: + + /// Underlying tile iterator + Base iterator_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator( + TensorRef const &ref, + int lane_id + ): iterator_({ref.data(), ref.stride()}, lane_id) { + } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator( + TensorRef const &ref, + TensorCoord const & extent, + int lane_id + ): iterator_({ref.data(), ref.stride()}, extent, lane_id) { + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { + + iterator_.add_pointer_offset(offset); + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator++() { + + ++iterator_; + + return *this; + } + + /// Advances the iterator along the 
advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator--() { + + --iterator_; + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { + add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + iterator_.load(frag); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index byte_offset) const { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + iterator_.load_with_byte_offset( + frag, + {tile_offset.contiguous(), tile_offset.strided()}, + byte_offset); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. 
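// Illustrative sketch (not part of the patch): how a warp-level mainloop
// typically drives one of these wrapper iterators.  'WarpIterator' stands for
// any MmaTensorOpMultiplicandTileIterator instantiation, 'GroupsPerTile' is
// assumed to be a compile-time constant, and only members defined in this
// class are used.
//
//   template <typename WarpIterator, int GroupsPerTile>
//   CUTLASS_DEVICE
//   void consume_tile(typename WarpIterator::TensorRef ref, int lane_id) {
//
//     WarpIterator iter(ref, lane_id);
//     typename WarpIterator::Fragment frag;
//
//     CUTLASS_PRAGMA_UNROLL
//     for (int k = 0; k < GroupsPerTile; ++k) {
//       iter.set_kgroup_index(k);   // 'k' is a constant after unrolling, so the
//                                   // iterator's internal k-group tracking folds away
//       iter.load(frag);
//       ++iter;                     // advance to the next k-group
//       // ... issue the warp-level MMA on 'frag' here ...
//     }
//   }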
+ CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + iterator_.set_kgroup_index(k_group); + } +}; + + +/// Wrapper for RowMajor +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Identifies A or B multiplicand + Operand Operand_, + /// Data type of elements + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions) + int OpDelta_, + /// Number of partitions along K dimension + int PartitionsK_> +class MmaTensorOpMultiplicandTileIterator< + Shape_, Operand_, Element_, + cutlass::layout::RowMajor, + InstructionShape_, OpDelta_, 32, PartitionsK_> { + public: + + /// Shape of tile to load (concept: PitchLinearShape) + using Shape = Shape_; + + /// Operand tag + static Operand const kOperand = Operand_; + + static_assert(kOperand == Operand::kA || kOperand== Operand::kB, + "MmaTensorOpMultiplicandIterator may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: MatrixShape) + static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Underlying tile iterator implementation + using Base = MmaTensorOpMultiplicandTileIteratorCanonical< + Shape, kOperand, Element, + layout::RowMajor, + InstructionShape, + kOpDelta, kThreads, PartitionsK_>; + + public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = typename Base::Fragment; + +private: + + /// Underlying tile iterator + Base iterator_; + +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator( + TensorRef const &ref, + int lane_id + ): iterator_({ref.data(), ref.stride()}, lane_id) { + } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator( + TensorRef const &ref, + TensorCoord const &extent, + int lane_id + ): iterator_({ref.data(), ref.stride()}, extent, lane_id) { + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_pointer_offset(LongIndex offset) { + + iterator_.add_pointer_offset(offset); + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator &add_tile_offset(TensorCoord const &tile_offset) { + + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator++() { + + ++iterator_; + + return *this; + } + + /// Advances the iterator 
along the advance dimension + CUTLASS_HOST_DEVICE + MmaTensorOpMultiplicandTileIterator & operator--() { + + --iterator_; + + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator+=(TensorCoord const &tile_offset) { + add_tile_offset(PitchLinearCoord(tile_offset.row(), tile_offset.column())); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of the tensor + CUTLASS_DEVICE + MmaTensorOpMultiplicandTileIterator & operator-=(TensorCoord const &tile_offset) { + add_tile_offset(-PitchLinearCoord(tile_offset.row(), tile_offset.column())); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + iterator_.load(frag); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index byte_offset) const { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + // TODO + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + iterator_.load_with_byte_offset( + frag, + {tile_offset.contiguous(), tile_offset.strided()}, + byte_offset); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. 
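// Illustrative sketch (not part of the patch): the pointer-offset and
// byte-offset load paths above differ only in units.  For the non-sub-byte
// element types used with these wrappers, a pointer offset of N elements
// corresponds to N * sizeof(Element) bytes, so the two calls below (with
// 'Iterator' any instantiation of this class and 'kOffset' an element count
// chosen by the caller) are expected to fetch the same fragment:
//
//   Iterator iter(ref, lane_id);
//   typename Iterator::Fragment frag_a, frag_b;
//
//   iter.load_with_pointer_offset(frag_a, kOffset);
//   iter.load_with_byte_offset(frag_b, kOffset * sizeof(Element));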
+ CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + iterator_.set_kgroup_index(k_group); + } +}; + + //////////////////////////////////////////////////////////////////////////////// } // namespace warp diff --git a/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h new file mode 100644 index 0000000000..a7e69816f1 --- /dev/null +++ b/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator_sparse.h @@ -0,0 +1,374 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators to load sparse meta data used by warp-level matrix multiply operations + targeting Sparse Tensor Cores. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" + +#include "cutlass/platform/platform.h" +#include "cutlass/fast_math.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Data type of A elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Delta between *MMA operations (in units of *MMA operations, concept: + /// MatrixShape) + int OpDelta_, + /// Number of threads participating in one matrix operation + int Threads, + /// Number of partitions along K dimension + int PartitionsK_ = 1> +class SparseMmaTensorOpMetaTileIterator { + public: + /// Shape of tile to load (concept: PitchLinearShape) + using Shape = Shape_; + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = Layout_; + + /// Shape of one matrix product operation (concept: GemmShape) + using InstructionShape = InstructionShape_; + + /// Delta between *MMA operations (in units of *MMA operations, concept: + /// MatrixShape) + static int const kOpDelta = OpDelta_; + + /// Number of participating threads + static int const kThreads = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + + static int const kSparse = 2; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Internal structure of iterator - made public to enable introspection + struct Policy { + static_assert( + !(Shape::kColumn % InstructionShape::kColumn), + "Shape of warp-level Mma must be divisible by operator shape."); + + static int const kElementsPerAccess = 128 / sizeof_bits::value; + + // Determine number of elements along outer dimension per individual LDSM op + static int const kLdsmOpOuter = InstructionShape::kColumn; + static int const kLdsmOpInner = 8 * kElementsPerAccess / kLdsmOpOuter; + + static_assert(!(Shape::kColumn % kLdsmOpOuter), + "Shape of warp-level mma must be divisible by LDSM's " + "fundamental tile size."); + + static_assert(!(Shape::kRow % kLdsmOpInner), + "Shape of warp-level mma must be divisible by LDSM's " + "fundamental tile size."); + + /// Shape of one individual LDSM instruction + static int const LdsmShapeColumn = + InstructionShape::kColumn / kLdsmOpOuter; + static int const LdsmShapeRow = + ((4 / LdsmShapeColumn * kLdsmOpInner) > Shape::kRow) + ? 
(Shape::kRow / kLdsmOpInner) + : (4 / LdsmShapeColumn); + using LdsmShape = + layout::PitchLinearShape; + + /// Number and arrangement of LDSM instructions + using LdsmIterations = layout::PitchLinearShape< + Shape::kRow / kLdsmOpInner / LdsmShapeRow, + 1>; + + /// Number of groups for each tile + static int const kGroupsPerTile = + Shape::kColumn / InstructionShape::kColumn; + }; + + private: + /// Not working on this feature at the moment. + static_assert(kOpDelta == 1, + "Alternative arrangements not supported at present."); + + /// Pointer type used for accesses + using AccessType = Array; + + public: + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = + Array; + + private: + + /// Layout object storing stride values + Index stride_; + + /// Shared memory base pointers - not advanced + AccessType const *pointer_; + + /// Byte offset incremented as iterator advances + Index byte_offset_; + + /// Internal counter used to determine when to increment byte offset and when + /// to XOR it + int k_group_idx_; + + public: + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + SparseMmaTensorOpMetaTileIterator() + : pointer_(nullptr), + stride_(0), + byte_offset_(0), + k_group_idx_(0) {} + + /// Constructor from TensorRef + CUTLASS_DEVICE + SparseMmaTensorOpMetaTileIterator(TensorRef const &ref, int lane_id) + : pointer_(reinterpret_cast(ref.data())), + stride_(ref.stride(0) / Policy::kElementsPerAccess), + byte_offset_(0), + k_group_idx_(0) { + + int access_contiguous = (lane_id % (Shape::kRow / Policy::kElementsPerAccess)); + int access_strided = (lane_id / (Shape::kRow / Policy::kElementsPerAccess)); + + byte_offset_ = (access_contiguous + access_strided * stride_) * + sizeof_bits::value * Policy::kElementsPerAccess / 8; + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_DEVICE + SparseMmaTensorOpMetaTileIterator &add_pointer_offset(LongIndex offset) { + byte_offset_ += offset * sizeof_bits::value / 8; + + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_DEVICE + SparseMmaTensorOpMetaTileIterator &add_tile_offset( + TensorCoord const &tile_offset) { + int offset = tile_offset.row() * Shape::kRow + + tile_offset.column() * InstructionShape::kColumn * stride_ * + Policy::kElementsPerAccess; + + add_pointer_offset(offset); + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + SparseMmaTensorOpMetaTileIterator &operator++() { + add_tile_offset({0, 1}); + + if (kPartitionsK > 1) { + ++k_group_idx_; + // Jump to next stage + if (k_group_idx_ == Policy::kGroupsPerTile) { + k_group_idx_ = 0; + add_tile_offset( + {0, ((kPartitionsK - 1) * Policy::kGroupsPerTile)}); + } + } + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + SparseMmaTensorOpMetaTileIterator &operator--(){ + byte_offset_ -= stride_ * InstructionShape::kColumn * + sizeof_bits::value * Policy::kElementsPerAccess / + 8; + } + + ///< advances in units of whole tiles along the logical coordinate space of + ///< the tensor + CUTLASS_DEVICE SparseMmaTensorOpMetaTileIterator & + operator+=(TensorCoord const &tile_offset) { + add_tile_offset(tile_offset); + return *this; + } + + ///< advances in units of whole tiles along the logical coordinate space of + ///< the tensor + CUTLASS_DEVICE + SparseMmaTensorOpMetaTileIterator &operator-=( + TensorCoord const 
&tile_offset) { + add_tile_offset(-tile_offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { load_with_byte_offset(frag, 0); } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset in units of bytes + Index byte_offset) const { + Array *fetch_ptr = + reinterpret_cast *>(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < Policy::LdsmIterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < Policy::LdsmIterations::kContiguous; ++c) { + + int access_idx = c + s * Policy::LdsmIterations::kContiguous; + + AccessType const *source_ptr = + pointer_ + + Policy::LdsmShape::kContiguous * Policy::kLdsmOpInner * c + + Policy::LdsmShape::kStrided * s * stride_; + + char const *source_byte_ptr = reinterpret_cast(source_ptr) + + byte_offset + byte_offset_; + + cutlass::arch::ldsm( + fetch_ptr[access_idx], source_byte_ptr); + } + } + } + + /// Loads a fragment from memory with additional logical offset + CUTLASS_DEVICE + void load_with_pointer_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a linear offset + Index pointer_offset) const { + load_with_byte_offset(frag, pointer_offset * sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset) const { + load_with_byte_offset(frag, tile_offset, 0); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index pointer_offset) const { + load_with_byte_offset(frag, tile_offset, pointer_offset * sizeof(Element)); + } + + /// Loads a fragment from memory with logical offset in units of whole tiles. + CUTLASS_DEVICE + void load_with_byte_offset( + /// fragment to load from the tensor + Fragment &frag, + /// loads a tile with a logical offset in units of whole tiles + TensorCoord const &tile_offset, + /// loads a tile with a logical offset AND a pointer offset + Index byte_offset) const { + Index pointer_offset = + tile_offset.contiguous() * Shape::kRow / Layout::kElementsPerAccess + + tile_offset.strided() * InstructionShape::kColumn * stride_; + + byte_offset += sizeof(AccessType) * pointer_offset; + + load_with_byte_offset(frag, byte_offset); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index to enable the compiler to + /// fold constants and achieve more efficient code. + /// + /// This is used by some nontrivial permuted layouts. 
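// Illustrative sketch (not part of the patch): on Ampere Sparse Tensor Cores
// the A operand is stored compressed (2 of every 4 elements along K are kept,
// matching kSparse = 2 above), and the metadata loaded by this iterator
// records which positions were kept.  A warp-level sparse mainloop therefore
// steps a metadata iterator in lock-step with the compressed-A iterator.
// 'IteratorA', 'IteratorE', 'iter_A', 'iter_E', and the final mma call are
// placeholders for the corresponding warp-level types configured elsewhere in
// this patch.
//
//   typename IteratorA::Fragment frag_A;   // compressed A values
//   typename IteratorE::Fragment frag_E;   // sparsity metadata (this class)
//
//   CUTLASS_PRAGMA_UNROLL
//   for (int k = 0; k < Policy::kGroupsPerTile; ++k) {
//     iter_A.load(frag_A);
//     iter_E.load(frag_E);
//     ++iter_A;
//     ++iter_E;                            // operator++ advances one k-group
//     // mma(accum, frag_A, frag_B, accum, frag_E);  // metadata fed to the sparse MMA
//   }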
+ CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + // no op + } +}; + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/half.h b/include/cutlass/half.h index 10d00de1c2..3d0bd34724 100644 --- a/include/cutlass/half.h +++ b/include/cutlass/half.h @@ -349,7 +349,7 @@ struct alignas(2) half_t { /// Default constructor CUTLASS_HOST_DEVICE - half_t() { } + half_t() : storage(0) { } /// Reinterpret cast from CUDA's half type CUTLASS_HOST_DEVICE diff --git a/include/cutlass/integer_subbyte.h b/include/cutlass/integer_subbyte.h index 6b97f8222a..df32042d0e 100644 --- a/include/cutlass/integer_subbyte.h +++ b/include/cutlass/integer_subbyte.h @@ -83,11 +83,10 @@ struct integer_subbyte { integer_subbyte(unsigned value) : storage(reinterpret_cast(value) & kMask) {} - /// Conversion from double CUTLASS_HOST_DEVICE integer_subbyte(double value) { - T tmp = (T)value; - storage = reinterpret_cast(tmp) & kMask; + T tmp = static_cast(value); + storage = Storage(reinterpret_cast(tmp) & kMask); } /// @@ -155,6 +154,12 @@ struct integer_subbyte { /// 1-bit Unsigned integer type using uint1b_t = integer_subbyte<1, false>; +/// 2-bit Integer type +using int2b_t = integer_subbyte<2, true>; + +/// 2-bit Unsigned integer type +using uint2b_t = integer_subbyte<2, false>; + /// 4-bit Integer type using int4b_t = integer_subbyte<4, true>; @@ -169,6 +174,18 @@ struct sizeof_bits { static int const value = 1; }; +/// Defines the size of an element in bits - specialized for int2b_t +template <> +struct sizeof_bits { + static int const value = 2; +}; + +/// Defines the size of an element in bits - specialized for uint2b_t +template <> +struct sizeof_bits { + static int const value = 2; +}; + /// Defines the size of an element in bits - specialized for int4b_t template <> struct sizeof_bits { diff --git a/include/cutlass/layout/matrix.h b/include/cutlass/layout/matrix.h index 7c02f8f2c2..0590492625 100644 --- a/include/cutlass/layout/matrix.h +++ b/include/cutlass/layout/matrix.h @@ -35,7 +35,6 @@ #include "cutlass/cutlass.h" #include "cutlass/matrix_coord.h" -#include "cutlass/matrix_traits.h" namespace cutlass { namespace layout { @@ -803,7 +802,7 @@ struct GeneralMatrix { // Data members // - MatrixLayout layout_id_; + Matrix layout_id_; /// Stride data member Stride stride_; @@ -815,12 +814,12 @@ struct GeneralMatrix { /// Ctor CUTLASS_HOST_DEVICE - GeneralMatrix(): layout_id_(MatrixLayout::kColumnMajor), stride_(make_Coord(0, 1)) { } + GeneralMatrix(): layout_id_(Matrix::kColumnMajor), stride_(make_Coord(0, 1)) { } /// Ctor CUTLASS_HOST_DEVICE GeneralMatrix( - MatrixLayout layout_id, + Matrix layout_id, Index ldm, Index interleave): layout_id_(layout_id), stride_(make_Coord(ldm, interleave)) { } @@ -828,11 +827,11 @@ struct GeneralMatrix { CUTLASS_HOST_DEVICE static GeneralMatrix packed( MatrixCoord const &extent, - MatrixLayout layout_id = MatrixLayout::kColumnMajor, + Matrix layout_id = Matrix::kColumnMajor, Index interleave = 1) { Index c; - if (layout_id == MatrixLayout::kRowMajor) { + if (layout_id == Matrix::kRowMajor) { c = extent.column(); } else { @@ -849,7 +848,7 @@ struct GeneralMatrix { CUTLASS_HOST_DEVICE LongIndex operator()(MatrixCoord const &coord) const { Index c, s; - if (layout_id_ == MatrixLayout::kRowMajor) { + if (layout_id_ == Matrix::kRowMajor) { c = coord.column(); s = coord.row(); } @@ -871,7 +870,7 @@ struct GeneralMatrix { } CUTLASS_HOST_DEVICE - 
MatrixLayout layout_id() const { + Matrix layout_id() const { return layout_id_; } @@ -882,7 +881,7 @@ struct GeneralMatrix { } CUTLASS_HOST_DEVICE - MatrixLayout & layout_id() { + Matrix & layout_id() { return layout_id_; } @@ -902,7 +901,7 @@ struct GeneralMatrix { CUTLASS_HOST_DEVICE LongIndex capacity(MatrixCoord const &extent) const { Index s; - if (layout_id_ == MatrixLayout::kRowMajor) { + if (layout_id_ == Matrix::kRowMajor) { s = extent.row(); } else { diff --git a/include/cutlass/layout/tensor.h b/include/cutlass/layout/tensor.h index 20d5bad777..f3d5a12bf8 100644 --- a/include/cutlass/layout/tensor.h +++ b/include/cutlass/layout/tensor.h @@ -79,7 +79,7 @@ class TensorNHWC { // Data members // - /// Stride data member - [c, wc, hwc] + /// Stride data member - [stride_w, stride_h, stride_n] Stride stride_; public: @@ -93,7 +93,12 @@ class TensorNHWC { /// Constructor CUTLASS_HOST_DEVICE - TensorNHWC(typename Stride::Index c, typename Stride::Index wc, typename Stride::Index hwc): stride_(make_Coord(c, wc, hwc)) { } + TensorNHWC( + typename Stride::Index stride_w, ///< number of elements between adjacent W coordinates + typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates + typename Stride::Index stride_n ///< number of elements between adjacent N coordinates + ): + stride_(make_Coord(stride_w, stride_h, stride_n)) { } /// Helper returns a layout to a tightly packed NHWC tensor. CUTLASS_HOST_DEVICE @@ -116,12 +121,6 @@ class TensorNHWC { LongIndex(stride_[2] * coord.n()); } - /// Returns a RowMajor equivalent for a TensorNHWC layout - CUTLASS_HOST_DEVICE - explicit operator RowMajor() { - return RowMajor(stride_[0]); - } - /// Returns the logical coordinate (n, h, w, c) from a given offset in linear memory. CUTLASS_HOST_DEVICE TensorCoord inverse(LongIndex index) const { @@ -444,6 +443,107 @@ class TensorCxRSKx { } }; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mapping function for 5-D NDHWC tensors. +class TensorNDHWC { +public: + /// Logical rank of tensor + static int const kRank = 5; + + /// Rank of stride vector + static int const kStrideRank = 4; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate (n, d, h, w, c) + using TensorCoord = Tensor5DCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member - [c, wc, hwc, dhwc] + Stride stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + TensorNDHWC(Stride const &stride = Stride(0)): stride_(stride) { } + + /// Constructor + CUTLASS_HOST_DEVICE + TensorNDHWC( + typename Stride::Index c, + typename Stride::Index wc, + typename Stride::Index hwc, + typename Stride::Index dhwc): + stride_(make_Coord(c, wc, hwc, dhwc)) { } + + /// Helper returns a layout to a tightly packed NHWC tensor. + CUTLASS_HOST_DEVICE + static TensorNDHWC packed(TensorCoord const &extent) { + return TensorNDHWC( + make_Coord( + extent.c(), + extent.w() * extent.c(), + extent.h() * extent.w() * extent.c(), + extent.d() * extent.h() * extent.w() * extent.c() + ) + ); + } + + /// Returns the offset of a coordinate (n, d, h, w, c) in linear memory. 
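// Illustrative note (not part of the patch): with the packed() strides
// [C, W*C, H*W*C, D*H*W*C], the offset computed below reduces to
//
//   offset(n, d, h, w, c) = c + C * (w + W * (h + H * (d + D * n)))
//
// For example, assuming an extent of (N, D, H, W, C) = (2, 3, 4, 5, 8):
//
//   TensorNDHWC layout = TensorNDHWC::packed({2, 3, 4, 5, 8});
//   // layout.stride() == {8, 40, 160, 480}
//   // layout({1, 2, 3, 4, 7}) == 7 + 8*4 + 40*3 + 160*2 + 480*1 == 959
//   // layout.capacity({2, 3, 4, 5, 8}) == 2 * 480 == 960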
+ CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return coord.c() + + LongIndex(stride_[0] * coord.w()) + + LongIndex(stride_[1] * coord.h()) + + LongIndex(stride_[2] * coord.d()) + + LongIndex(stride_[3] * coord.n()); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + // it does not make sense if the extent is larger than stride + // and we could not rely on the capacity calculation in such cases + // we could move this checkers to debug code only + if ((extent.c() > stride_[0]) + || (extent.w() * stride_[0] > stride_[1]) + || (extent.h() * stride_[1] > stride_[2]) + || (extent.d() * stride_[2] > stride_[3])) { + assert(0); + } + return extent.n() * stride_[3]; + } +}; + + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace layout diff --git a/include/cutlass/layout/tensor_op_multiplicand_sm75.h b/include/cutlass/layout/tensor_op_multiplicand_sm75.h index 00870fb50f..b52483355c 100644 --- a/include/cutlass/layout/tensor_op_multiplicand_sm75.h +++ b/include/cutlass/layout/tensor_op_multiplicand_sm75.h @@ -81,17 +81,23 @@ struct TensorOpMultiplicand { static int const kFactor = kTileShapeContiguous * kElementsPerAccess / kCrosswise; - /// The strided dimension needs to be at least WarpSize(32) / - /// kTileShapeContiguous for a warp to access. To ensure conflict free + static_assert( + (kFactor > 0), + "kCrosswise should be no large than one shared memory cache line."); + + /// The strided dimension needs to be at least (WarpSize(32) / + /// kTileShapeContiguous) for a warp to access. To ensure conflict free /// access, it also needs to be at least (kTileShapeContiguous / kFactor). + /// See comments below static int const kTileShapeStride = ((kTileShapeContiguous / kFactor) > (32 / kTileShapeContiguous)) ? (kTileShapeContiguous / kFactor) : (32 / kTileShapeContiguous); - /// Fundamental tile shape in units of vectors - /// For TN kblock=32 and 8x8x16 shapes, TileShape = <8, 4>. - /// For the rest, TileShape = <8, 8> + /// Fundamental tile shape in units of vectors to guarantee bank conflict free + /// shared memory load/store. + /// For kFactor = 1, TileShape = <8, 8> + /// For kFactor > 1, TileShape = <8, 4> using TileShape = PitchLinearShape; /// Fundamental partition shape in units of vectors diff --git a/include/cutlass/matrix.h b/include/cutlass/matrix.h new file mode 100644 index 0000000000..5d05ee8994 --- /dev/null +++ b/include/cutlass/matrix.h @@ -0,0 +1,14111 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + \file + \brief Matrix classes with value semantics. +*/ + +#pragma once + +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/fast_math.h" +#include "cutlass/layout/matrix.h" + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Primary template with partial specializations to follow +template struct Matrix; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 1-by-2 matrix template class definition +template +struct Matrix { + + // + // Type definitions + // + + /// Element data type + using Element = Element_; + + /// Number of rows in matrix + static int const kRows = 1; + + /// Number of columns in matrix + static int const kColumns = 2; + + /// Layout of matrix in underlying array + using Layout = layout::RowMajor; + + /// Number of elements in matrix + static int const kCount = 2; + + // + // Data members + // + + /// Elements of the matrix in row-major layout + Array data; + + // + // Methods + // + + /// Constructs a zero matrix + CUTLASS_HOST_DEVICE + Matrix() { + data.clear(); + } + + /// Copy constructor for a 1-by-2 matrix + CUTLASS_HOST_DEVICE + Matrix(Matrix const &rhs) { + data = rhs.data; + } + + /// Constucts a 1-by-2 matrix from scalar elements + CUTLASS_HOST_DEVICE + Matrix( + Element _0_0, Element _0_1 + ) { + + data[0] = _0_0; data[1] = _0_1; + } + + /// Constructs a matrix from a uniform element + CUTLASS_HOST_DEVICE + static Matrix uniform(Element s) { + Matrix m; + + m.data[0] = s; + m.data[1] = s; + + return m; + } + + /// Constructs a matrix from a uniform element 1 + CUTLASS_HOST_DEVICE + static Matrix ones() { + return uniform(Element(1)); + } + + /// Constructs a matrix from a uniform element 0 + CUTLASS_HOST_DEVICE + static Matrix zero() { + return Matrix(); + } + + /// Returns a transposed matrix + CUTLASS_HOST_DEVICE + Matrix transpose() const { + Matrix mt; + + mt.data[0] = data[0]; + mt.data[1] = data[1]; + + return mt; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(int i, int j) const { + return data[i * 1 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(int i, int j) { + return data[i * 1 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(Coord<2> const &coord) const { + 
return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element &at(int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element at(int offset) const { + return data[offset]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element operator[](Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & operator[](Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element & operator[](int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element operator[](int offset) const { + return data[offset]; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 2 + j + 0]; + m.data[1] = data[i * 2 + j + 1]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 2 + j + 0] = m.data[0]; + data[i * 2 + j + 1] = m.data[1]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix row(int i) const { + return slice_1x2(i, 0); + } + + Matrix &set_row(Matrix const &v, int i = 0) { + return set_slice_1x2(v, i, 0); + } + + /// Forms a 1-by-2 matrix by horizontally concatenating an Element with an Element + CUTLASS_HOST_DEVICE + static Matrix hcat(Element lhs, Element rhs) { + return Matrix( + lhs, rhs); + } + + /// Concatenates this matrix with a an Element to form a 1-by-3 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Element rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Concatenates this matrix with a a 1-by-2 matrix to form a 1-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Concatenates this matrix with a a 1-by-2 matrix to form a 2-by-2 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Concatenates this matrix with a a 2-by-2 matrix to form a 3-by-2 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Concatenates this matrix with a a 3-by-2 matrix to form a 4-by-2 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Elementwise add operator (1-by-2) + CUTLASS_HOST_DEVICE + Matrix add(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] + rhs.data[0]; + result.data[1] = data[1] + rhs.data[1]; + + return result; + } + + /// Elementwise add operator (1-by-2) + CUTLASS_HOST_DEVICE + Matrix operator +(Matrix const &rhs) const { + return add(rhs); + } + + /// Elementwise add operator (1-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator +=(Matrix const &rhs) { + + data[0] += rhs.data[0]; + data[1] += rhs.data[1]; + + return *this; + } + + /// Elementwise subtract operator (1-by-2) + CUTLASS_HOST_DEVICE + Matrix subtract(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] - rhs.data[0]; + result.data[1] = data[1] - rhs.data[1]; + + return result; + } + + /// Elementwise subtract operator (1-by-2) + CUTLASS_HOST_DEVICE + Matrix 
operator -(Matrix const &rhs) const { + return subtract(rhs); + } + + /// Elementwise subtract operator (1-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator -=(Matrix const &rhs) { + + data[0] -= rhs.data[0]; + data[1] -= rhs.data[1]; + + return *this; + } + + /// Elementwise multiply operator (1-by-2) + CUTLASS_HOST_DEVICE + Matrix multiply(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] * rhs.data[0]; + result.data[1] = data[1] * rhs.data[1]; + + return result; + } + + /// Scalar multiply operator (1-by-2) + CUTLASS_HOST_DEVICE + Matrix multiply(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] * s; + result.data[1] = data[1] * s; + + return result; + } + + /// Scalar multiply operator (1-by-2) + CUTLASS_HOST_DEVICE + Matrix operator *(Element const &s) const { + return multiply(s); + } + + /// Scalar multiply operator (1-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator *=(Element const &s) { + + data[0] *= s; + data[1] *= s; + + return *this; + } + + /// Elementwise divide operator (1-by-2) + CUTLASS_HOST_DEVICE + Matrix divide(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] / rhs.data[0]; + result.data[1] = data[1] / rhs.data[1]; + + return result; + } + + /// Scalar divide operator (1-by-2) + CUTLASS_HOST_DEVICE + Matrix divide(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] / s; + result.data[1] = data[1] / s; + + return result; + } + + /// Scalar divide operator (1-by-2) + CUTLASS_HOST_DEVICE + Matrix operator /(Element const &s) const { + return divide(s); + } + + /// Scalar divide operator (1-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Element const &s) { + + data[0] /= s; + data[1] /= s; + + return *this; + } + + /// Elementwise divide operator (1-by-2) + CUTLASS_HOST_DEVICE + Matrix operator /(Matrix const &rhs) const { + return divide(rhs); + } + + /// Elementwise divide operator (1-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Matrix const &rhs) { + + data[0] /= rhs.data[0]; + data[1] /= rhs.data[1]; + + return *this; + } + + /// Negates each element of the matrix + CUTLASS_HOST_DEVICE + Matrix operator-() const { + Matrix m; + + m.data[0] = -m.data[0]; + m.data[1] = -m.data[1]; + + return m; + } + + /// Matrix product of size 1-by-1-by-2 + CUTLASS_HOST_DEVICE + Element product(Matrix const &rhs, Element accum = Element()) const { + + // k=0 + accum += data[0] * rhs.data[0]; + + // k=1 + accum += data[1] * rhs.data[1]; + + return accum; + } + + /// Matrix product of size 1-by-1-by-2 + CUTLASS_HOST_DEVICE + Element operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 1-by-2-by-2 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + + // k=1 + accum.data[0] += data[1] * rhs.data[2]; + accum.data[1] += data[1] * rhs.data[3]; + + return accum; + } + + /// Matrix product of size 1-by-2-by-2 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 1-by-2-by-2 + CUTLASS_HOST_DEVICE + Matrix & operator*=(Matrix const &rhs) { + *this = product(rhs); + return *this; + } + + /// Matrix product of size 1-by-3-by-2 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += 
data[0] * rhs.data[2]; + + // k=1 + accum.data[0] += data[1] * rhs.data[3]; + accum.data[1] += data[1] * rhs.data[4]; + accum.data[2] += data[1] * rhs.data[5]; + + return accum; + } + + /// Matrix product of size 1-by-3-by-2 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 1-by-4-by-2 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[0] * rhs.data[3]; + + // k=1 + accum.data[0] += data[1] * rhs.data[4]; + accum.data[1] += data[1] * rhs.data[5]; + accum.data[2] += data[1] * rhs.data[6]; + accum.data[3] += data[1] * rhs.data[7]; + + return accum; + } + + /// Matrix product of size 1-by-4-by-2 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Dot product of vectors with extent 2 + CUTLASS_HOST_DEVICE + Element dot(Matrix const &rhs, Element accum = Element()) const { + + accum += data[0] * rhs.data[0]; + accum += data[1] * rhs.data[1]; + return accum; + } + + /// Dot product of vectors with extent 2 + CUTLASS_HOST_DEVICE + Element dot(Matrix const &rhs, Element accum = Element()) const { + + accum += data[0] * rhs.data[0]; + accum += data[1] * rhs.data[1]; + return accum; + } + + /// Returns the sum of elements + CUTLASS_HOST_DEVICE + Element sum(Element accum = Element()) const { + + accum += data[0]; + accum += data[1]; + + return accum; + } + + /// Returns the sum of squared elements + CUTLASS_HOST_DEVICE + Element norm(Element accum = Element()) const { + + accum += data[0] * data[0]; + accum += data[1] * data[1]; + + return accum; + } + + /// Returns square root of the norm + CUTLASS_HOST_DEVICE + Element magnitude() const { + return fast_sqrt(norm()); + } + + /// Returns the sum of diagonal elements + CUTLASS_HOST_DEVICE + Element trace(Element accum = Element()) const { + + accum += data[0]; + + return accum; + } + +}; + +/// Template alias for 1-by-2 matrix +template +using Matrix1x2 = Matrix; + + +/// Free funciton to infer element type from template arguments +template +CUTLASS_HOST_DEVICE Matrix1x2 make_Matrix1x2( + Element _0_0, Element _0_1 +) { + return Matrix1x2( + _0_0, _0_1 + ); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 1-by-3 matrix template class definition +template +struct Matrix { + + // + // Type definitions + // + + /// Element data type + using Element = Element_; + + /// Number of rows in matrix + static int const kRows = 1; + + /// Number of columns in matrix + static int const kColumns = 3; + + /// Layout of matrix in underlying array + using Layout = layout::RowMajor; + + /// Number of elements in matrix + static int const kCount = 3; + + // + // Data members + // + + /// Elements of the matrix in row-major layout + Array data; + + // + // Methods + // + + /// Constructs a zero matrix + CUTLASS_HOST_DEVICE + Matrix() { + data.clear(); + } + + /// Copy constructor for a 1-by-3 matrix + CUTLASS_HOST_DEVICE + Matrix(Matrix const &rhs) { + data = rhs.data; + } + + /// Constucts a 1-by-3 matrix from scalar elements + CUTLASS_HOST_DEVICE + Matrix( + Element _0_0, Element _0_1, Element _0_2 + ) { + + data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; + } + + /// Constructs a matrix from a uniform element + CUTLASS_HOST_DEVICE + static Matrix uniform(Element s) { + Matrix 
m; + + m.data[0] = s; + m.data[1] = s; + m.data[2] = s; + + return m; + } + + /// Constructs a matrix from a uniform element 1 + CUTLASS_HOST_DEVICE + static Matrix ones() { + return uniform(Element(1)); + } + + /// Constructs a matrix from a uniform element 0 + CUTLASS_HOST_DEVICE + static Matrix zero() { + return Matrix(); + } + + /// Returns a transposed matrix + CUTLASS_HOST_DEVICE + Matrix transpose() const { + Matrix mt; + + mt.data[0] = data[0]; + mt.data[1] = data[1]; + mt.data[2] = data[2]; + + return mt; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(int i, int j) const { + return data[i * 1 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(int i, int j) { + return data[i * 1 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element &at(int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element at(int offset) const { + return data[offset]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element operator[](Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & operator[](Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element & operator[](int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element operator[](int offset) const { + return data[offset]; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + m.data[2] = data[i * 3 + j + 2]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + data[i * 3 + j + 2] = m.data[2]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix row(int i) const { + return slice_1x3(i, 0); + } + + Matrix &set_row(Matrix const &v, int i = 0) { + return set_slice_1x3(v, i, 0); + } + + /// Forms a 1-by-3 matrix by horizontally concatenating an Element with a 1-by-2 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Element lhs, Matrix const & rhs) { + return Matrix( + lhs, rhs.at(0, 0), rhs.at(0, 1)); + } + + /// Forms a 1-by-3 matrix by horizontally concatenating a 1-by-2 matrix with an Element + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Element rhs) { + return Matrix( + lhs.at(0, 0), lhs.at(0, 1), rhs); + } + + /// Concatenates this matrix with a an Element to form a 1-by-4 matrix + CUTLASS_HOST_DEVICE + 
Matrix hcat(Element rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Concatenates this matrix with a a 1-by-3 matrix to form a 2-by-3 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Concatenates this matrix with a a 2-by-3 matrix to form a 3-by-3 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Concatenates this matrix with a a 3-by-3 matrix to form a 4-by-3 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Elementwise add operator (1-by-3) + CUTLASS_HOST_DEVICE + Matrix add(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] + rhs.data[0]; + result.data[1] = data[1] + rhs.data[1]; + result.data[2] = data[2] + rhs.data[2]; + + return result; + } + + /// Elementwise add operator (1-by-3) + CUTLASS_HOST_DEVICE + Matrix operator +(Matrix const &rhs) const { + return add(rhs); + } + + /// Elementwise add operator (1-by-3) + CUTLASS_HOST_DEVICE + Matrix & operator +=(Matrix const &rhs) { + + data[0] += rhs.data[0]; + data[1] += rhs.data[1]; + data[2] += rhs.data[2]; + + return *this; + } + + /// Elementwise subtract operator (1-by-3) + CUTLASS_HOST_DEVICE + Matrix subtract(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] - rhs.data[0]; + result.data[1] = data[1] - rhs.data[1]; + result.data[2] = data[2] - rhs.data[2]; + + return result; + } + + /// Elementwise subtract operator (1-by-3) + CUTLASS_HOST_DEVICE + Matrix operator -(Matrix const &rhs) const { + return subtract(rhs); + } + + /// Elementwise subtract operator (1-by-3) + CUTLASS_HOST_DEVICE + Matrix & operator -=(Matrix const &rhs) { + + data[0] -= rhs.data[0]; + data[1] -= rhs.data[1]; + data[2] -= rhs.data[2]; + + return *this; + } + + /// Elementwise multiply operator (1-by-3) + CUTLASS_HOST_DEVICE + Matrix multiply(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] * rhs.data[0]; + result.data[1] = data[1] * rhs.data[1]; + result.data[2] = data[2] * rhs.data[2]; + + return result; + } + + /// Scalar multiply operator (1-by-3) + CUTLASS_HOST_DEVICE + Matrix multiply(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] * s; + result.data[1] = data[1] * s; + result.data[2] = data[2] * s; + + return result; + } + + /// Scalar multiply operator (1-by-3) + CUTLASS_HOST_DEVICE + Matrix operator *(Element const &s) const { + return multiply(s); + } + + /// Scalar multiply operator (1-by-3) + CUTLASS_HOST_DEVICE + Matrix & operator *=(Element const &s) { + + data[0] *= s; + data[1] *= s; + data[2] *= s; + + return *this; + } + + /// Elementwise divide operator (1-by-3) + CUTLASS_HOST_DEVICE + Matrix divide(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] / rhs.data[0]; + result.data[1] = data[1] / rhs.data[1]; + result.data[2] = data[2] / rhs.data[2]; + + return result; + } + + /// Scalar divide operator (1-by-3) + CUTLASS_HOST_DEVICE + Matrix divide(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] / s; + result.data[1] = data[1] / s; + result.data[2] = data[2] / s; + + return result; + } + + /// Scalar divide operator (1-by-3) + CUTLASS_HOST_DEVICE + Matrix operator /(Element const &s) const { + return divide(s); + } + + /// Scalar divide operator (1-by-3) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Element const &s) { + + data[0] /= s; + data[1] /= s; + data[2] /= s; + + return 
*this; + } + + /// Elementwise divide operator (1-by-3) + CUTLASS_HOST_DEVICE + Matrix operator /(Matrix const &rhs) const { + return divide(rhs); + } + + /// Elementwise divide operator (1-by-3) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Matrix const &rhs) { + + data[0] /= rhs.data[0]; + data[1] /= rhs.data[1]; + data[2] /= rhs.data[2]; + + return *this; + } + + /// Negates each element of the matrix + CUTLASS_HOST_DEVICE + Matrix operator-() const { + Matrix m; + + m.data[0] = -m.data[0]; + m.data[1] = -m.data[1]; + m.data[2] = -m.data[2]; + + return m; + } + + /// Matrix product of size 1-by-1-by-3 + CUTLASS_HOST_DEVICE + Element product(Matrix const &rhs, Element accum = Element()) const { + + // k=0 + accum += data[0] * rhs.data[0]; + + // k=1 + accum += data[1] * rhs.data[1]; + + // k=2 + accum += data[2] * rhs.data[2]; + + return accum; + } + + /// Matrix product of size 1-by-1-by-3 + CUTLASS_HOST_DEVICE + Element operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 1-by-2-by-3 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + + // k=1 + accum.data[0] += data[1] * rhs.data[2]; + accum.data[1] += data[1] * rhs.data[3]; + + // k=2 + accum.data[0] += data[2] * rhs.data[4]; + accum.data[1] += data[2] * rhs.data[5]; + + return accum; + } + + /// Matrix product of size 1-by-2-by-3 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 1-by-3-by-3 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + + // k=1 + accum.data[0] += data[1] * rhs.data[3]; + accum.data[1] += data[1] * rhs.data[4]; + accum.data[2] += data[1] * rhs.data[5]; + + // k=2 + accum.data[0] += data[2] * rhs.data[6]; + accum.data[1] += data[2] * rhs.data[7]; + accum.data[2] += data[2] * rhs.data[8]; + + return accum; + } + + /// Matrix product of size 1-by-3-by-3 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 1-by-3-by-3 + CUTLASS_HOST_DEVICE + Matrix & operator*=(Matrix const &rhs) { + *this = product(rhs); + return *this; + } + + /// Matrix product of size 1-by-4-by-3 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[0] * rhs.data[3]; + + // k=1 + accum.data[0] += data[1] * rhs.data[4]; + accum.data[1] += data[1] * rhs.data[5]; + accum.data[2] += data[1] * rhs.data[6]; + accum.data[3] += data[1] * rhs.data[7]; + + // k=2 + accum.data[0] += data[2] * rhs.data[8]; + accum.data[1] += data[2] * rhs.data[9]; + accum.data[2] += data[2] * rhs.data[10]; + accum.data[3] += data[2] * rhs.data[11]; + + return accum; + } + + /// Matrix product of size 1-by-4-by-3 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Dot product of vectors with extent 3 + CUTLASS_HOST_DEVICE + Element dot(Matrix const &rhs, Element accum = Element()) const { + + accum += data[0] * rhs.data[0]; + accum += data[1] * rhs.data[1]; + accum += data[2] * rhs.data[2]; + return accum; + } + 
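// Illustrative sketch (not part of the patch): the Matrix classes defined in
// this file are small value types usable in both host and device code.  A
// minimal host-side example using only members that appear in this file (the
// free function make_Matrix1x3 is defined a few lines below):
//
//   #include "cutlass/matrix.h"
//
//   float demo() {
//     cutlass::Matrix1x3<float> u = cutlass::make_Matrix1x3(1.0f, 2.0f, 3.0f);
//     cutlass::Matrix1x3<float> v = cutlass::make_Matrix1x3(4.0f, 5.0f, 6.0f);
//
//     cutlass::Matrix1x3<float> w = u + v;       // elementwise add: (5, 7, 9)
//     cutlass::Matrix1x3<float> s = u * 2.0f;    // scalar multiply: (2, 4, 6)
//
//     float d = u.dot(v);                        // 1*4 + 2*5 + 3*6 = 32
//     float nrm = u.norm();                      // sum of squares: 1 + 4 + 9 = 14
//
//     return d + nrm + w.at(0, 2) + s.at(0, 0);  // 32 + 14 + 9 + 2 = 57
//   }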
+ /// Dot product of vectors with extent 3 + CUTLASS_HOST_DEVICE + Element dot(Matrix const &rhs, Element accum = Element()) const { + + accum += data[0] * rhs.data[0]; + accum += data[1] * rhs.data[1]; + accum += data[2] * rhs.data[2]; + return accum; + } + + /// Returns the sum of elements + CUTLASS_HOST_DEVICE + Element sum(Element accum = Element()) const { + + accum += data[0]; + accum += data[1]; + accum += data[2]; + + return accum; + } + + /// Returns the sum of squared elements + CUTLASS_HOST_DEVICE + Element norm(Element accum = Element()) const { + + accum += data[0] * data[0]; + accum += data[1] * data[1]; + accum += data[2] * data[2]; + + return accum; + } + + /// Returns square root of the norm + CUTLASS_HOST_DEVICE + Element magnitude() const { + return fast_sqrt(norm()); + } + + /// Returns the sum of diagonal elements + CUTLASS_HOST_DEVICE + Element trace(Element accum = Element()) const { + + accum += data[0]; + + return accum; + } + + /// Cross product + CUTLASS_HOST_DEVICE + Matrix cross(Matrix const &rhs) const { + return Matrix( + data[1] * rhs.data[2] - data[2] * rhs.data[1], + data[0] * rhs.data[2] - data[2] * rhs.data[1], + data[0] * rhs.data[1] - data[1] * rhs.data[0] + ); + } + +}; + +/// Template alias for 1-by-3 matrix +template +using Matrix1x3 = Matrix; + + +/// Free funciton to infer element type from template arguments +template +CUTLASS_HOST_DEVICE Matrix1x3 make_Matrix1x3( + Element _0_0, Element _0_1, Element _0_2 +) { + return Matrix1x3( + _0_0, _0_1, _0_2 + ); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 1-by-4 matrix template class definition +template +struct Matrix { + + // + // Type definitions + // + + /// Element data type + using Element = Element_; + + /// Number of rows in matrix + static int const kRows = 1; + + /// Number of columns in matrix + static int const kColumns = 4; + + /// Layout of matrix in underlying array + using Layout = layout::RowMajor; + + /// Number of elements in matrix + static int const kCount = 4; + + // + // Data members + // + + /// Elements of the matrix in row-major layout + Array data; + + // + // Methods + // + + /// Constructs a zero matrix + CUTLASS_HOST_DEVICE + Matrix() { + data.clear(); + } + + /// Copy constructor for a 1-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix(Matrix const &rhs) { + data = rhs.data; + } + + /// Constucts a 1-by-4 matrix from scalar elements + CUTLASS_HOST_DEVICE + Matrix( + Element _0_0, Element _0_1, Element _0_2, Element _0_3 + ) { + + data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3; + } + + /// Constructs a matrix from a uniform element + CUTLASS_HOST_DEVICE + static Matrix uniform(Element s) { + Matrix m; + + m.data[0] = s; + m.data[1] = s; + m.data[2] = s; + m.data[3] = s; + + return m; + } + + /// Constructs a matrix from a uniform element 1 + CUTLASS_HOST_DEVICE + static Matrix ones() { + return uniform(Element(1)); + } + + /// Constructs a matrix from a uniform element 0 + CUTLASS_HOST_DEVICE + static Matrix zero() { + return Matrix(); + } + + /// Returns a transposed matrix + CUTLASS_HOST_DEVICE + Matrix transpose() const { + Matrix mt; + + mt.data[0] = data[0]; + mt.data[1] = data[1]; + mt.data[2] = data[2]; + mt.data[3] = data[3]; + + return mt; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(int i, int j) const { + return data[i * 1 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(int i, int j) { + return data[i * 1 + j]; 
+ } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element &at(int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element at(int offset) const { + return data[offset]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element operator[](Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & operator[](Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element & operator[](int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element operator[](int offset) const { + return data[offset]; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x4(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + m.data[3] = data[i * 4 + j + 3]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x4(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + data[i * 4 + j + 3] = m.data[3]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix row(int i) const { + return slice_1x4(i, 0); + } + + Matrix &set_row(Matrix const &v, int i = 0) { + return set_slice_1x4(v, i, 0); + } + + /// Forms a 1-by-4 matrix by horizontally concatenating an Element with a 1-by-3 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Element lhs, Matrix const & rhs) { + return Matrix( + lhs, rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2)); + } + + /// Forms a 1-by-4 matrix by horizontally concatenating a 1-by-2 matrix with a 1-by-2 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1)); + } + + /// Forms a 1-by-4 matrix by horizontally concatenating a 1-by-3 matrix with an Element + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Element rhs) { + return Matrix( + lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), 
rhs); + } + + /// Concatenates this matrix with a a 1-by-4 matrix to form a 2-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Concatenates this matrix with a a 2-by-4 matrix to form a 3-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Concatenates this matrix with a a 3-by-4 matrix to form a 4-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Elementwise add operator (1-by-4) + CUTLASS_HOST_DEVICE + Matrix add(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] + rhs.data[0]; + result.data[1] = data[1] + rhs.data[1]; + result.data[2] = data[2] + rhs.data[2]; + result.data[3] = data[3] + rhs.data[3]; + + return result; + } + + /// Elementwise add operator (1-by-4) + CUTLASS_HOST_DEVICE + Matrix operator +(Matrix const &rhs) const { + return add(rhs); + } + + /// Elementwise add operator (1-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator +=(Matrix const &rhs) { + + data[0] += rhs.data[0]; + data[1] += rhs.data[1]; + data[2] += rhs.data[2]; + data[3] += rhs.data[3]; + + return *this; + } + + /// Elementwise subtract operator (1-by-4) + CUTLASS_HOST_DEVICE + Matrix subtract(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] - rhs.data[0]; + result.data[1] = data[1] - rhs.data[1]; + result.data[2] = data[2] - rhs.data[2]; + result.data[3] = data[3] - rhs.data[3]; + + return result; + } + + /// Elementwise subtract operator (1-by-4) + CUTLASS_HOST_DEVICE + Matrix operator -(Matrix const &rhs) const { + return subtract(rhs); + } + + /// Elementwise subtract operator (1-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator -=(Matrix const &rhs) { + + data[0] -= rhs.data[0]; + data[1] -= rhs.data[1]; + data[2] -= rhs.data[2]; + data[3] -= rhs.data[3]; + + return *this; + } + + /// Elementwise multiply operator (1-by-4) + CUTLASS_HOST_DEVICE + Matrix multiply(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] * rhs.data[0]; + result.data[1] = data[1] * rhs.data[1]; + result.data[2] = data[2] * rhs.data[2]; + result.data[3] = data[3] * rhs.data[3]; + + return result; + } + + /// Scalar multiply operator (1-by-4) + CUTLASS_HOST_DEVICE + Matrix multiply(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] * s; + result.data[1] = data[1] * s; + result.data[2] = data[2] * s; + result.data[3] = data[3] * s; + + return result; + } + + /// Scalar multiply operator (1-by-4) + CUTLASS_HOST_DEVICE + Matrix operator *(Element const &s) const { + return multiply(s); + } + + /// Scalar multiply operator (1-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator *=(Element const &s) { + + data[0] *= s; + data[1] *= s; + data[2] *= s; + data[3] *= s; + + return *this; + } + + /// Elementwise divide operator (1-by-4) + CUTLASS_HOST_DEVICE + Matrix divide(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] / rhs.data[0]; + result.data[1] = data[1] / rhs.data[1]; + result.data[2] = data[2] / rhs.data[2]; + result.data[3] = data[3] / rhs.data[3]; + + return result; + } + + /// Scalar divide operator (1-by-4) + CUTLASS_HOST_DEVICE + Matrix divide(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] / s; + result.data[1] = data[1] / s; + result.data[2] = data[2] / s; + result.data[3] = data[3] / s; + + return result; + } + + /// Scalar divide operator (1-by-4) + CUTLASS_HOST_DEVICE + 
Matrix operator /(Element const &s) const { + return divide(s); + } + + /// Scalar divide operator (1-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Element const &s) { + + data[0] /= s; + data[1] /= s; + data[2] /= s; + data[3] /= s; + + return *this; + } + + /// Elementwise divide operator (1-by-4) + CUTLASS_HOST_DEVICE + Matrix operator /(Matrix const &rhs) const { + return divide(rhs); + } + + /// Elementwise divide operator (1-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Matrix const &rhs) { + + data[0] /= rhs.data[0]; + data[1] /= rhs.data[1]; + data[2] /= rhs.data[2]; + data[3] /= rhs.data[3]; + + return *this; + } + + /// Negates each element of the matrix + CUTLASS_HOST_DEVICE + Matrix operator-() const { + Matrix m; + + m.data[0] = -m.data[0]; + m.data[1] = -m.data[1]; + m.data[2] = -m.data[2]; + m.data[3] = -m.data[3]; + + return m; + } + + /// Matrix product of size 1-by-1-by-4 + CUTLASS_HOST_DEVICE + Element product(Matrix const &rhs, Element accum = Element()) const { + + // k=0 + accum += data[0] * rhs.data[0]; + + // k=1 + accum += data[1] * rhs.data[1]; + + // k=2 + accum += data[2] * rhs.data[2]; + + // k=3 + accum += data[3] * rhs.data[3]; + + return accum; + } + + /// Matrix product of size 1-by-1-by-4 + CUTLASS_HOST_DEVICE + Element operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 1-by-2-by-4 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + + // k=1 + accum.data[0] += data[1] * rhs.data[2]; + accum.data[1] += data[1] * rhs.data[3]; + + // k=2 + accum.data[0] += data[2] * rhs.data[4]; + accum.data[1] += data[2] * rhs.data[5]; + + // k=3 + accum.data[0] += data[3] * rhs.data[6]; + accum.data[1] += data[3] * rhs.data[7]; + + return accum; + } + + /// Matrix product of size 1-by-2-by-4 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 1-by-3-by-4 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + + // k=1 + accum.data[0] += data[1] * rhs.data[3]; + accum.data[1] += data[1] * rhs.data[4]; + accum.data[2] += data[1] * rhs.data[5]; + + // k=2 + accum.data[0] += data[2] * rhs.data[6]; + accum.data[1] += data[2] * rhs.data[7]; + accum.data[2] += data[2] * rhs.data[8]; + + // k=3 + accum.data[0] += data[3] * rhs.data[9]; + accum.data[1] += data[3] * rhs.data[10]; + accum.data[2] += data[3] * rhs.data[11]; + + return accum; + } + + /// Matrix product of size 1-by-3-by-4 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 1-by-4-by-4 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[0] * rhs.data[3]; + + // k=1 + accum.data[0] += data[1] * rhs.data[4]; + accum.data[1] += data[1] * rhs.data[5]; + accum.data[2] += data[1] * rhs.data[6]; + accum.data[3] += data[1] * rhs.data[7]; + + // k=2 + accum.data[0] += data[2] * rhs.data[8]; + accum.data[1] += data[2] * rhs.data[9]; + accum.data[2] += data[2] * rhs.data[10]; + accum.data[3] += data[2] * 
rhs.data[11]; + + // k=3 + accum.data[0] += data[3] * rhs.data[12]; + accum.data[1] += data[3] * rhs.data[13]; + accum.data[2] += data[3] * rhs.data[14]; + accum.data[3] += data[3] * rhs.data[15]; + + return accum; + } + + /// Matrix product of size 1-by-4-by-4 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 1-by-4-by-4 + CUTLASS_HOST_DEVICE + Matrix & operator*=(Matrix const &rhs) { + *this = product(rhs); + return *this; + } + + /// Dot product of vectors with extent 4 + CUTLASS_HOST_DEVICE + Element dot(Matrix const &rhs, Element accum = Element()) const { + + accum += data[0] * rhs.data[0]; + accum += data[1] * rhs.data[1]; + accum += data[2] * rhs.data[2]; + accum += data[3] * rhs.data[3]; + return accum; + } + + /// Dot product of vectors with extent 4 + CUTLASS_HOST_DEVICE + Element dot(Matrix const &rhs, Element accum = Element()) const { + + accum += data[0] * rhs.data[0]; + accum += data[1] * rhs.data[1]; + accum += data[2] * rhs.data[2]; + accum += data[3] * rhs.data[3]; + return accum; + } + + /// Returns the sum of elements + CUTLASS_HOST_DEVICE + Element sum(Element accum = Element()) const { + + accum += data[0]; + accum += data[1]; + accum += data[2]; + accum += data[3]; + + return accum; + } + + /// Returns the sum of squared elements + CUTLASS_HOST_DEVICE + Element norm(Element accum = Element()) const { + + accum += data[0] * data[0]; + accum += data[1] * data[1]; + accum += data[2] * data[2]; + accum += data[3] * data[3]; + + return accum; + } + + /// Returns square root of the norm + CUTLASS_HOST_DEVICE + Element magnitude() const { + return fast_sqrt(norm()); + } + + /// Returns the sum of diagonal elements + CUTLASS_HOST_DEVICE + Element trace(Element accum = Element()) const { + + accum += data[0]; + + return accum; + } + +}; + +/// Template alias for 1-by-4 matrix +template +using Matrix1x4 = Matrix; + + +/// Free funciton to infer element type from template arguments +template +CUTLASS_HOST_DEVICE Matrix1x4 make_Matrix1x4( + Element _0_0, Element _0_1, Element _0_2, Element _0_3 +) { + return Matrix1x4( + _0_0, _0_1, _0_2, _0_3 + ); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 2-by-1 matrix template class definition +template +struct Matrix { + + // + // Type definitions + // + + /// Element data type + using Element = Element_; + + /// Number of rows in matrix + static int const kRows = 2; + + /// Number of columns in matrix + static int const kColumns = 1; + + /// Layout of matrix in underlying array + using Layout = layout::RowMajor; + + /// Number of elements in matrix + static int const kCount = 2; + + // + // Data members + // + + /// Elements of the matrix in row-major layout + Array data; + + // + // Methods + // + + /// Constructs a zero matrix + CUTLASS_HOST_DEVICE + Matrix() { + data.clear(); + } + + /// Copy constructor for a 2-by-1 matrix + CUTLASS_HOST_DEVICE + Matrix(Matrix const &rhs) { + data = rhs.data; + } + + /// Constucts a 2-by-1 matrix from scalar elements + CUTLASS_HOST_DEVICE + Matrix( + Element _0_0, + Element _1_0 + ) { + + data[0] = _0_0; + data[1] = _1_0; + } + + /// Constructs a matrix from a uniform element + CUTLASS_HOST_DEVICE + static Matrix uniform(Element s) { + Matrix m; + + m.data[0] = s; + m.data[1] = s; + + return m; + } + + /// Constructs a matrix from a uniform element 1 + CUTLASS_HOST_DEVICE + static Matrix ones() { + return uniform(Element(1)); + } + + /// Constructs a matrix 
from a uniform element 0 + CUTLASS_HOST_DEVICE + static Matrix zero() { + return Matrix(); + } + + /// Returns a transposed matrix + CUTLASS_HOST_DEVICE + Matrix transpose() const { + Matrix mt; + + mt.data[0] = data[0]; + mt.data[1] = data[1]; + + return mt; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(int i, int j) const { + return data[i * 2 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(int i, int j) { + return data[i * 2 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element &at(int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element at(int offset) const { + return data[offset]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element operator[](Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & operator[](Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element & operator[](int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element operator[](int offset) const { + return data[offset]; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 1 + j + 0]; + m.data[1] = data[i * 1 + j + 1]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 1 + j + 0] = m.data[0]; + data[i * 1 + j + 1] = m.data[1]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix column(int j) const { + return slice_2x1(0, j); + } + + Matrix &set_column(Matrix const &v, int j =0) { + return set_slice_2x1(v, 0, j); + } + + /// Concatenates this matrix with a a 2-by-1 matrix to form a 2-by-2 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Concatenates this matrix with a a 2-by-2 matrix to form a 2-by-3 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Concatenates this matrix with a a 2-by-3 matrix to form a 2-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Forms a 2-by-1 matrix by vertically concatenating an Element with an Element + CUTLASS_HOST_DEVICE + static Matrix vcat(Element upper, Element lower) { + return Matrix( + upper + , lower); + } + + /// Concatenates this matrix with a an Element to form a 3-by-1 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Element rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Concatenates this matrix with a a 2-by-1 matrix to form a 4-by-1 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Elementwise add operator (2-by-1) + CUTLASS_HOST_DEVICE + Matrix add(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] + rhs.data[0]; + + result.data[1] = data[1] + rhs.data[1]; + + return result; 
+ } + + /// Elementwise add operator (2-by-1) + CUTLASS_HOST_DEVICE + Matrix operator +(Matrix const &rhs) const { + return add(rhs); + } + + /// Elementwise add operator (2-by-1) + CUTLASS_HOST_DEVICE + Matrix & operator +=(Matrix const &rhs) { + + data[0] += rhs.data[0]; + + data[1] += rhs.data[1]; + + return *this; + } + + /// Elementwise subtract operator (2-by-1) + CUTLASS_HOST_DEVICE + Matrix subtract(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] - rhs.data[0]; + + result.data[1] = data[1] - rhs.data[1]; + + return result; + } + + /// Elementwise subtract operator (2-by-1) + CUTLASS_HOST_DEVICE + Matrix operator -(Matrix const &rhs) const { + return subtract(rhs); + } + + /// Elementwise subtract operator (2-by-1) + CUTLASS_HOST_DEVICE + Matrix & operator -=(Matrix const &rhs) { + + data[0] -= rhs.data[0]; + + data[1] -= rhs.data[1]; + + return *this; + } + + /// Elementwise multiply operator (2-by-1) + CUTLASS_HOST_DEVICE + Matrix multiply(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] * rhs.data[0]; + + result.data[1] = data[1] * rhs.data[1]; + + return result; + } + + /// Scalar multiply operator (2-by-1) + CUTLASS_HOST_DEVICE + Matrix multiply(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] * s; + + result.data[1] = data[1] * s; + + return result; + } + + /// Scalar multiply operator (2-by-1) + CUTLASS_HOST_DEVICE + Matrix operator *(Element const &s) const { + return multiply(s); + } + + /// Scalar multiply operator (2-by-1) + CUTLASS_HOST_DEVICE + Matrix & operator *=(Element const &s) { + + data[0] *= s; + + data[1] *= s; + + return *this; + } + + /// Elementwise divide operator (2-by-1) + CUTLASS_HOST_DEVICE + Matrix divide(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] / rhs.data[0]; + + result.data[1] = data[1] / rhs.data[1]; + + return result; + } + + /// Scalar divide operator (2-by-1) + CUTLASS_HOST_DEVICE + Matrix divide(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] / s; + + result.data[1] = data[1] / s; + + return result; + } + + /// Scalar divide operator (2-by-1) + CUTLASS_HOST_DEVICE + Matrix operator /(Element const &s) const { + return divide(s); + } + + /// Scalar divide operator (2-by-1) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Element const &s) { + + data[0] /= s; + + data[1] /= s; + + return *this; + } + + /// Elementwise divide operator (2-by-1) + CUTLASS_HOST_DEVICE + Matrix operator /(Matrix const &rhs) const { + return divide(rhs); + } + + /// Elementwise divide operator (2-by-1) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Matrix const &rhs) { + + data[0] /= rhs.data[0]; + + data[1] /= rhs.data[1]; + + return *this; + } + + /// Negates each element of the matrix + CUTLASS_HOST_DEVICE + Matrix operator-() const { + Matrix m; + + m.data[0] = -m.data[0]; + m.data[1] = -m.data[1]; + + return m; + } + + /// Matrix product of size 2-by-1-by-1 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[1] * rhs.data[0]; + + return accum; + } + + /// Matrix product of size 2-by-1-by-1 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 2-by-1-by-1 + CUTLASS_HOST_DEVICE + Matrix & operator*=(Matrix const &rhs) { + *this = product(rhs); + return *this; + } + + /// Matrix product of size 2-by-2-by-1 + CUTLASS_HOST_DEVICE + Matrix product( 
+ Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[1] * rhs.data[0]; + accum.data[3] += data[1] * rhs.data[1]; + + return accum; + } + + /// Matrix product of size 2-by-2-by-1 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 2-by-3-by-1 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[1] * rhs.data[0]; + accum.data[4] += data[1] * rhs.data[1]; + accum.data[5] += data[1] * rhs.data[2]; + + return accum; + } + + /// Matrix product of size 2-by-3-by-1 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 2-by-4-by-1 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[0] * rhs.data[3]; + accum.data[4] += data[1] * rhs.data[0]; + accum.data[5] += data[1] * rhs.data[1]; + accum.data[6] += data[1] * rhs.data[2]; + accum.data[7] += data[1] * rhs.data[3]; + + return accum; + } + + /// Matrix product of size 2-by-4-by-1 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Dot product of vectors with extent 2 + CUTLASS_HOST_DEVICE + Element dot(Matrix const &rhs, Element accum = Element()) const { + + accum += data[0] * rhs.data[0]; + accum += data[1] * rhs.data[1]; + return accum; + } + + /// Dot product of vectors with extent 2 + CUTLASS_HOST_DEVICE + Element dot(Matrix const &rhs, Element accum = Element()) const { + + accum += data[0] * rhs.data[0]; + accum += data[1] * rhs.data[1]; + return accum; + } + + /// Returns the sum of elements + CUTLASS_HOST_DEVICE + Element sum(Element accum = Element()) const { + + accum += data[0]; + accum += data[1]; + + return accum; + } + + /// Returns the sum of squared elements + CUTLASS_HOST_DEVICE + Element norm(Element accum = Element()) const { + + accum += data[0] * data[0]; + accum += data[1] * data[1]; + + return accum; + } + + /// Returns square root of the norm + CUTLASS_HOST_DEVICE + Element magnitude() const { + return fast_sqrt(norm()); + } + + /// Returns the sum of diagonal elements + CUTLASS_HOST_DEVICE + Element trace(Element accum = Element()) const { + + accum += data[0]; + + return accum; + } + +}; + +/// Template alias for 2-by-1 matrix +template +using Matrix2x1 = Matrix; + + +/// Free funciton to infer element type from template arguments +template +CUTLASS_HOST_DEVICE Matrix2x1 make_Matrix2x1( + Element _0_0, + Element _1_0 +) { + return Matrix2x1( + _0_0, + _1_0 + ); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 2-by-2 matrix template class definition +template +struct Matrix { + + // + // Type definitions + // + + /// Element data type + using Element = Element_; + + /// Number of rows in matrix + static int const kRows = 2; + + /// Number of columns in matrix + static int const kColumns = 2; + + /// Layout of matrix in underlying array + using Layout = layout::RowMajor; + + /// Number of elements in matrix + static int const kCount = 4; + 
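  //
  // Storage note: elements are kept contiguously in row-major order, so for this
  // 2-by-2 specialization element (i, j) maps to data[i * kColumns + j]:
  //
  //   data[0] = (0,0), data[1] = (0,1), data[2] = (1,0), data[3] = (1,1)
  //
  // This is the mapping the at(i, j) accessors and the slice_*/set_slice_* helpers
  // below rely on.
  //
  // Usage sketch (float elements assumed; illustration only, not library code):
  //
  //   cutlass::Matrix2x2<float> R    = cutlass::Matrix2x2<float>::rotation(theta);  // [c, -s; s, c] for some angle theta
  //   float                     det  = R.determinant();                             // c*c + s*s == 1
  //   cutlass::Matrix2x2<float> Rinv = R.inverse();                                 // equals R.transpose() for a rotation
  //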
+ // + // Data members + // + + /// Elements of the matrix in row-major layout + Array data; + + // + // Methods + // + + /// Constructs a zero matrix + CUTLASS_HOST_DEVICE + Matrix() { + data.clear(); + } + + /// Copy constructor for a 2-by-2 matrix + CUTLASS_HOST_DEVICE + Matrix(Matrix const &rhs) { + data = rhs.data; + } + + /// Constucts a 2-by-2 matrix from scalar elements + CUTLASS_HOST_DEVICE + Matrix( + Element _0_0, Element _0_1, + Element _1_0, Element _1_1 + ) { + + data[0] = _0_0; data[1] = _0_1; + data[2] = _1_0; data[3] = _1_1; + } + + /// Constucts a 2-by-2 matrix from row vectors + CUTLASS_HOST_DEVICE + Matrix( + Matrix const &row_0, + Matrix const &row_1 + ) { + data[0] = row_0.data[0]; + data[1] = row_0.data[1]; + data[2] = row_1.data[0]; + data[3] = row_1.data[1]; + } + + /// Static method to construct a 2-by-2 matrix from column vectors + CUTLASS_HOST_DEVICE + static Matrix from_columns( + Matrix const &column_0, + Matrix const &column_1 + ) { + Matrix result; + + result.data[0] = column_0.data[0]; + result.data[1] = column_1.data[0]; + result.data[2] = column_0.data[1]; + result.data[3] = column_1.data[1]; + return result; + } + + /// Constructs an identity matrix + CUTLASS_HOST_DEVICE + static Matrix identity() { + Matrix m; + + m.data[0] = Element(1); + m.data[3] = Element(1); + + return m; + } + + /// Constructs a matrix from a uniform element + CUTLASS_HOST_DEVICE + static Matrix uniform(Element s) { + Matrix m; + + m.data[0] = s; + m.data[1] = s; + m.data[2] = s; + m.data[3] = s; + + return m; + } + + /// Constructs a matrix from a uniform element 1 + CUTLASS_HOST_DEVICE + static Matrix ones() { + return uniform(Element(1)); + } + + /// Constructs a matrix from a uniform element 0 + CUTLASS_HOST_DEVICE + static Matrix zero() { + return Matrix(); + } + + /// Constructs a matrix from elements along its diagonal + CUTLASS_HOST_DEVICE + static Matrix from_diagonal(Matrix const &diag) { + Matrix m; + + m.data[0] = diag.data[0]; + m.data[3] = diag.data[1]; + + return m; + } + + /// Constructs a matrix from elements along its diagonal + CUTLASS_HOST_DEVICE + static Matrix from_diagonal(Matrix const &diag) { + Matrix m; + + m.data[0] = diag.data[0]; + m.data[3] = diag.data[1]; + + return m; + } + + /// Gets an array of diagonal elements + CUTLASS_HOST_DEVICE + Matrix diagonal() const { + Matrix diag; + + diag.data[0] = data[0]; + diag.data[1] = data[3]; + + return diag; + } + + /// Returns a transposed matrix + CUTLASS_HOST_DEVICE + Matrix transpose() const { + Matrix mt; + + mt.data[0] = data[0]; + mt.data[2] = data[1]; + mt.data[1] = data[2]; + mt.data[3] = data[3]; + + return mt; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(int i, int j) const { + return data[i * 2 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(int i, int j) { + return data[i * 2 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element &at(int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element at(int offset) const { + return data[offset]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element operator[](Coord<2> const &coord) const { + return 
at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & operator[](Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element & operator[](int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element operator[](int offset) const { + return data[offset]; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 2 + j + 0]; + m.data[1] = data[i * 2 + j + 1]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 2 + j + 0] = m.data[0]; + data[i * 2 + j + 1] = m.data[1]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix row(int i) const { + return slice_1x2(i, 0); + } + + Matrix &set_row(Matrix const &v, int i = 0) { + return set_slice_1x2(v, i, 0); + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 2 + j + 0]; + m.data[1] = data[i * 2 + j + 2]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 2 + j + 0] = m.data[0]; + data[i * 2 + j + 2] = m.data[1]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix column(int j) const { + return slice_2x1(0, j); + } + + Matrix &set_column(Matrix const &v, int j =0) { + return set_slice_2x1(v, 0, j); + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 2 + j + 0]; + m.data[1] = data[i * 2 + j + 1]; + m.data[2] = data[i * 2 + j + 2]; + m.data[3] = data[i * 2 + j + 3]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 2 + j + 0] = m.data[0]; + data[i * 2 + j + 1] = m.data[1]; + data[i * 2 + j + 2] = m.data[2]; + data[i * 2 + j + 3] = m.data[3]; + + return *this; + } + + /// Forms a 2-by-2 matrix by horizontally concatenating a 2-by-1 matrix with a 2-by-1 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), rhs.at(0, 0) + , lhs.at(1, 0), rhs.at(1, 0)); + } + + /// Concatenates this matrix with a a 2-by-1 matrix to form a 2-by-3 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Concatenates this matrix with a a 2-by-2 matrix to form a 2-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Forms a 2-by-2 matrix by vertically concatenating a 1-by-2 matrix with a 1-by-2 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Matrix const & lower) { + return Matrix( + upper.at(0, 0), upper.at(0, 1) + , lower.at(0, 0), lower.at(0, 1)); + } + + /// Concatenates this matrix with a a 1-by-2 matrix to form a 3-by-2 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Concatenates this matrix with a a 2-by-2 matrix to form a 4-by-2 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// 
Forms a 2-by-2 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Element A, Element B, + Element C, Element D) { + return Matrix( + A, B + , C, D + ); + } + + /// Elementwise add operator (2-by-2) + CUTLASS_HOST_DEVICE + Matrix add(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] + rhs.data[0]; + result.data[1] = data[1] + rhs.data[1]; + + result.data[2] = data[2] + rhs.data[2]; + result.data[3] = data[3] + rhs.data[3]; + + return result; + } + + /// Elementwise add operator (2-by-2) + CUTLASS_HOST_DEVICE + Matrix operator +(Matrix const &rhs) const { + return add(rhs); + } + + /// Elementwise add operator (2-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator +=(Matrix const &rhs) { + + data[0] += rhs.data[0]; + data[1] += rhs.data[1]; + + data[2] += rhs.data[2]; + data[3] += rhs.data[3]; + + return *this; + } + + /// Elementwise subtract operator (2-by-2) + CUTLASS_HOST_DEVICE + Matrix subtract(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] - rhs.data[0]; + result.data[1] = data[1] - rhs.data[1]; + + result.data[2] = data[2] - rhs.data[2]; + result.data[3] = data[3] - rhs.data[3]; + + return result; + } + + /// Elementwise subtract operator (2-by-2) + CUTLASS_HOST_DEVICE + Matrix operator -(Matrix const &rhs) const { + return subtract(rhs); + } + + /// Elementwise subtract operator (2-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator -=(Matrix const &rhs) { + + data[0] -= rhs.data[0]; + data[1] -= rhs.data[1]; + + data[2] -= rhs.data[2]; + data[3] -= rhs.data[3]; + + return *this; + } + + /// Elementwise multiply operator (2-by-2) + CUTLASS_HOST_DEVICE + Matrix multiply(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] * rhs.data[0]; + result.data[1] = data[1] * rhs.data[1]; + + result.data[2] = data[2] * rhs.data[2]; + result.data[3] = data[3] * rhs.data[3]; + + return result; + } + + /// Scalar multiply operator (2-by-2) + CUTLASS_HOST_DEVICE + Matrix multiply(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] * s; + result.data[1] = data[1] * s; + + result.data[2] = data[2] * s; + result.data[3] = data[3] * s; + + return result; + } + + /// Scalar multiply operator (2-by-2) + CUTLASS_HOST_DEVICE + Matrix operator *(Element const &s) const { + return multiply(s); + } + + /// Scalar multiply operator (2-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator *=(Element const &s) { + + data[0] *= s; + data[1] *= s; + + data[2] *= s; + data[3] *= s; + + return *this; + } + + /// Elementwise divide operator (2-by-2) + CUTLASS_HOST_DEVICE + Matrix divide(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] / rhs.data[0]; + result.data[1] = data[1] / rhs.data[1]; + + result.data[2] = data[2] / rhs.data[2]; + result.data[3] = data[3] / rhs.data[3]; + + return result; + } + + /// Scalar divide operator (2-by-2) + CUTLASS_HOST_DEVICE + Matrix divide(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] / s; + result.data[1] = data[1] / s; + + result.data[2] = data[2] / s; + result.data[3] = data[3] / s; + + return result; + } + + /// Scalar divide operator (2-by-2) + CUTLASS_HOST_DEVICE + Matrix operator /(Element const &s) const { + return divide(s); + } + + /// Scalar divide operator (2-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Element const &s) { + + data[0] /= s; + data[1] /= s; + + data[2] /= s; + data[3] /= s; + + return *this; + } + + /// Elementwise divide operator (2-by-2) + CUTLASS_HOST_DEVICE + Matrix operator 
/(Matrix const &rhs) const { + return divide(rhs); + } + + /// Elementwise divide operator (2-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Matrix const &rhs) { + + data[0] /= rhs.data[0]; + data[1] /= rhs.data[1]; + + data[2] /= rhs.data[2]; + data[3] /= rhs.data[3]; + + return *this; + } + + /// Negates each element of the matrix + CUTLASS_HOST_DEVICE + Matrix operator-() const { + Matrix m; + + m.data[0] = -m.data[0]; + m.data[1] = -m.data[1]; + m.data[2] = -m.data[2]; + m.data[3] = -m.data[3]; + + return m; + } + + /// Matrix product of size 2-by-1-by-2 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[2] * rhs.data[0]; + + // k=1 + accum.data[0] += data[1] * rhs.data[1]; + accum.data[1] += data[3] * rhs.data[1]; + + return accum; + } + + /// Matrix product of size 2-by-1-by-2 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 2-by-2-by-2 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[2] * rhs.data[0]; + accum.data[3] += data[2] * rhs.data[1]; + + // k=1 + accum.data[0] += data[1] * rhs.data[2]; + accum.data[1] += data[1] * rhs.data[3]; + accum.data[2] += data[3] * rhs.data[2]; + accum.data[3] += data[3] * rhs.data[3]; + + return accum; + } + + /// Matrix product of size 2-by-2-by-2 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 2-by-2-by-2 + CUTLASS_HOST_DEVICE + Matrix & operator*=(Matrix const &rhs) { + *this = product(rhs); + return *this; + } + + /// Matrix product of size 2-by-3-by-2 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[2] * rhs.data[0]; + accum.data[4] += data[2] * rhs.data[1]; + accum.data[5] += data[2] * rhs.data[2]; + + // k=1 + accum.data[0] += data[1] * rhs.data[3]; + accum.data[1] += data[1] * rhs.data[4]; + accum.data[2] += data[1] * rhs.data[5]; + accum.data[3] += data[3] * rhs.data[3]; + accum.data[4] += data[3] * rhs.data[4]; + accum.data[5] += data[3] * rhs.data[5]; + + return accum; + } + + /// Matrix product of size 2-by-3-by-2 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 2-by-4-by-2 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[0] * rhs.data[3]; + accum.data[4] += data[2] * rhs.data[0]; + accum.data[5] += data[2] * rhs.data[1]; + accum.data[6] += data[2] * rhs.data[2]; + accum.data[7] += data[2] * rhs.data[3]; + + // k=1 + accum.data[0] += data[1] * rhs.data[4]; + accum.data[1] += data[1] * rhs.data[5]; + accum.data[2] += data[1] * rhs.data[6]; + accum.data[3] += data[1] * rhs.data[7]; + accum.data[4] += data[3] * rhs.data[4]; + accum.data[5] += data[3] * rhs.data[5]; + accum.data[6] += data[3] * rhs.data[6]; + accum.data[7] += data[3] * rhs.data[7]; + + return accum; + } + + /// Matrix product of size 
2-by-4-by-2 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Returns the sum of elements + CUTLASS_HOST_DEVICE + Element sum(Element accum = Element()) const { + + accum += data[0]; + accum += data[1]; + accum += data[2]; + accum += data[3]; + + return accum; + } + + /// Returns the sum of squared elements + CUTLASS_HOST_DEVICE + Element norm(Element accum = Element()) const { + + accum += data[0] * data[0]; + accum += data[1] * data[1]; + accum += data[2] * data[2]; + accum += data[3] * data[3]; + + return accum; + } + + /// Returns square root of the norm + CUTLASS_HOST_DEVICE + Element magnitude() const { + return fast_sqrt(norm()); + } + + /// Returns the sum of diagonal elements + CUTLASS_HOST_DEVICE + Element trace(Element accum = Element()) const { + + accum += data[0]; + accum += data[3]; + + return accum; + } + + /// Returns 2-by-2 rotation matrix + CUTLASS_HOST_DEVICE + static Matrix rotation(Element theta) { + Element c = fast_cos(theta); + Element s = fast_sin(theta); + + return Matrix( + c, -s, + s, c + ); + } + + /// Computes the determinant of a 2-by-2 matrix + CUTLASS_HOST_DEVICE + Element determinant(Element accum = Element()) const { + accum += data[0] * data[3] - data[1] * data[2]; + + return accum; + } + + /// Computes the inverse of a 2-by-2 matrix given + /// the matrix's determinant + CUTLASS_HOST_DEVICE + Matrix inverse(Element det) const { + return Matrix( + data[3], -data[1], + -data[2], data[0] + ) * (Element(1) / det); + } + + /// Computes the inverse of a 2-by-2 matrix. + CUTLASS_HOST_DEVICE + Matrix inverse() const { + return inverse(determinant()); + } + +}; + +/// Template alias for 2-by-2 matrix +template +using Matrix2x2 = Matrix; + + +/// Free funciton to infer element type from template arguments +template +CUTLASS_HOST_DEVICE Matrix2x2 make_Matrix2x2( + Element _0_0, Element _0_1, + Element _1_0, Element _1_1 +) { + return Matrix2x2( + _0_0, _0_1, + _1_0, _1_1 + ); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 2-by-3 matrix template class definition +template +struct Matrix { + + // + // Type definitions + // + + /// Element data type + using Element = Element_; + + /// Number of rows in matrix + static int const kRows = 2; + + /// Number of columns in matrix + static int const kColumns = 3; + + /// Layout of matrix in underlying array + using Layout = layout::RowMajor; + + /// Number of elements in matrix + static int const kCount = 6; + + // + // Data members + // + + /// Elements of the matrix in row-major layout + Array data; + + // + // Methods + // + + /// Constructs a zero matrix + CUTLASS_HOST_DEVICE + Matrix() { + data.clear(); + } + + /// Copy constructor for a 2-by-3 matrix + CUTLASS_HOST_DEVICE + Matrix(Matrix const &rhs) { + data = rhs.data; + } + + /// Constucts a 2-by-3 matrix from scalar elements + CUTLASS_HOST_DEVICE + Matrix( + Element _0_0, Element _0_1, Element _0_2, + Element _1_0, Element _1_1, Element _1_2 + ) { + + data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; + data[3] = _1_0; data[4] = _1_1; data[5] = _1_2; + } + + /// Constucts a 2-by-3 matrix from row vectors + CUTLASS_HOST_DEVICE + Matrix( + Matrix const &row_0, + Matrix const &row_1 + ) { + data[0] = row_0.data[0]; + data[1] = row_0.data[1]; + data[2] = row_0.data[2]; + data[3] = row_1.data[0]; + data[4] = row_1.data[1]; + data[5] = row_1.data[2]; + } + + /// Static method to construct a 2-by-3 matrix from column vectors + CUTLASS_HOST_DEVICE + static 
Matrix from_columns( + Matrix const &column_0, + Matrix const &column_1, + Matrix const &column_2 + ) { + Matrix result; + + result.data[0] = column_0.data[0]; + result.data[1] = column_1.data[0]; + result.data[2] = column_2.data[0]; + result.data[3] = column_0.data[1]; + result.data[4] = column_1.data[1]; + result.data[5] = column_2.data[1]; + return result; + } + + /// Constructs a matrix from a uniform element + CUTLASS_HOST_DEVICE + static Matrix uniform(Element s) { + Matrix m; + + m.data[0] = s; + m.data[1] = s; + m.data[2] = s; + m.data[3] = s; + m.data[4] = s; + m.data[5] = s; + + return m; + } + + /// Constructs a matrix from a uniform element 1 + CUTLASS_HOST_DEVICE + static Matrix ones() { + return uniform(Element(1)); + } + + /// Constructs a matrix from a uniform element 0 + CUTLASS_HOST_DEVICE + static Matrix zero() { + return Matrix(); + } + + /// Constructs a matrix from elements along its diagonal + CUTLASS_HOST_DEVICE + static Matrix from_diagonal(Matrix const &diag) { + Matrix m; + + m.data[0] = diag.data[0]; + m.data[3] = diag.data[1]; + + return m; + } + + /// Constructs a matrix from elements along its diagonal + CUTLASS_HOST_DEVICE + static Matrix from_diagonal(Matrix const &diag) { + Matrix m; + + m.data[0] = diag.data[0]; + m.data[3] = diag.data[1]; + + return m; + } + + /// Gets an array of diagonal elements + CUTLASS_HOST_DEVICE + Matrix diagonal() const { + Matrix diag; + + diag.data[0] = data[0]; + diag.data[1] = data[3]; + + return diag; + } + + /// Returns a transposed matrix + CUTLASS_HOST_DEVICE + Matrix transpose() const { + Matrix mt; + + mt.data[0] = data[0]; + mt.data[2] = data[1]; + mt.data[4] = data[2]; + mt.data[1] = data[3]; + mt.data[3] = data[4]; + mt.data[5] = data[5]; + + return mt; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(int i, int j) const { + return data[i * 2 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(int i, int j) { + return data[i * 2 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element &at(int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element at(int offset) const { + return data[offset]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element operator[](Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & operator[](Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element & operator[](int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element operator[](int offset) const { + return data[offset]; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + + return *this; + } + + /// 
Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + m.data[2] = data[i * 3 + j + 2]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + data[i * 3 + j + 2] = m.data[2]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix row(int i) const { + return slice_1x3(i, 0); + } + + Matrix &set_row(Matrix const &v, int i = 0) { + return set_slice_1x3(v, i, 0); + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 3]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 3] = m.data[1]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix column(int j) const { + return slice_2x1(0, j); + } + + Matrix &set_column(Matrix const &v, int j =0) { + return set_slice_2x1(v, 0, j); + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + m.data[2] = data[i * 3 + j + 3]; + m.data[3] = data[i * 3 + j + 4]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + data[i * 3 + j + 3] = m.data[2]; + data[i * 3 + j + 4] = m.data[3]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + m.data[2] = data[i * 3 + j + 2]; + m.data[3] = data[i * 3 + j + 3]; + m.data[4] = data[i * 3 + j + 4]; + m.data[5] = data[i * 3 + j + 5]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + data[i * 3 + j + 2] = m.data[2]; + data[i * 3 + j + 3] = m.data[3]; + data[i * 3 + j + 4] = m.data[4]; + data[i * 3 + j + 5] = m.data[5]; + + return *this; + } + + /// Forms a 2-by-3 matrix by horizontally concatenating a 2-by-1 matrix with a 2-by-2 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1) + , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1)); + } + + /// Forms a 2-by-3 matrix by horizontally concatenating a 2-by-2 matrix with a 2-by-1 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0) + , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0)); + } + + /// Concatenates this matrix with a a 2-by-1 matrix to form a 2-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Forms a 2-by-3 matrix by vertically concatenating a 1-by-3 matrix with a 1-by-3 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix 
const & upper, Matrix const & lower) { + return Matrix( + upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) + , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2)); + } + + /// Concatenates this matrix with a a 1-by-3 matrix to form a 3-by-3 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Concatenates this matrix with a a 2-by-3 matrix to form a 4-by-3 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Forms a 2-by-3 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Element A, Matrix const & B, + Element C, Matrix const & D) { + return Matrix( + A, B.at(0, 0), B.at(0, 1) + , C, D.at(0, 0), D.at(0, 1) + ); + } + + /// Forms a 2-by-3 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Element B, + Matrix const & C, Element D) { + return Matrix( + A.at(0, 0), A.at(0, 1), B + , C.at(0, 0), C.at(0, 1), D + ); + } + + /// Elementwise add operator (2-by-3) + CUTLASS_HOST_DEVICE + Matrix add(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] + rhs.data[0]; + result.data[1] = data[1] + rhs.data[1]; + result.data[2] = data[2] + rhs.data[2]; + + result.data[3] = data[3] + rhs.data[3]; + result.data[4] = data[4] + rhs.data[4]; + result.data[5] = data[5] + rhs.data[5]; + + return result; + } + + /// Elementwise add operator (2-by-3) + CUTLASS_HOST_DEVICE + Matrix operator +(Matrix const &rhs) const { + return add(rhs); + } + + /// Elementwise add operator (2-by-3) + CUTLASS_HOST_DEVICE + Matrix & operator +=(Matrix const &rhs) { + + data[0] += rhs.data[0]; + data[1] += rhs.data[1]; + data[2] += rhs.data[2]; + + data[3] += rhs.data[3]; + data[4] += rhs.data[4]; + data[5] += rhs.data[5]; + + return *this; + } + + /// Elementwise subtract operator (2-by-3) + CUTLASS_HOST_DEVICE + Matrix subtract(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] - rhs.data[0]; + result.data[1] = data[1] - rhs.data[1]; + result.data[2] = data[2] - rhs.data[2]; + + result.data[3] = data[3] - rhs.data[3]; + result.data[4] = data[4] - rhs.data[4]; + result.data[5] = data[5] - rhs.data[5]; + + return result; + } + + /// Elementwise subtract operator (2-by-3) + CUTLASS_HOST_DEVICE + Matrix operator -(Matrix const &rhs) const { + return subtract(rhs); + } + + /// Elementwise subtract operator (2-by-3) + CUTLASS_HOST_DEVICE + Matrix & operator -=(Matrix const &rhs) { + + data[0] -= rhs.data[0]; + data[1] -= rhs.data[1]; + data[2] -= rhs.data[2]; + + data[3] -= rhs.data[3]; + data[4] -= rhs.data[4]; + data[5] -= rhs.data[5]; + + return *this; + } + + /// Elementwise multiply operator (2-by-3) + CUTLASS_HOST_DEVICE + Matrix multiply(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] * rhs.data[0]; + result.data[1] = data[1] * rhs.data[1]; + result.data[2] = data[2] * rhs.data[2]; + + result.data[3] = data[3] * rhs.data[3]; + result.data[4] = data[4] * rhs.data[4]; + result.data[5] = data[5] * rhs.data[5]; + + return result; + } + + /// Scalar multiply operator (2-by-3) + CUTLASS_HOST_DEVICE + Matrix multiply(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] * s; + result.data[1] = data[1] * s; + result.data[2] = data[2] * s; + + result.data[3] = data[3] * s; + result.data[4] = data[4] * s; + result.data[5] = data[5] * s; + + return result; + } + + /// Scalar multiply operator (2-by-3) + CUTLASS_HOST_DEVICE + Matrix 
operator *(Element const &s) const { + return multiply(s); + } + + /// Scalar multiply operator (2-by-3) + CUTLASS_HOST_DEVICE + Matrix & operator *=(Element const &s) { + + data[0] *= s; + data[1] *= s; + data[2] *= s; + + data[3] *= s; + data[4] *= s; + data[5] *= s; + + return *this; + } + + /// Elementwise divide operator (2-by-3) + CUTLASS_HOST_DEVICE + Matrix divide(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] / rhs.data[0]; + result.data[1] = data[1] / rhs.data[1]; + result.data[2] = data[2] / rhs.data[2]; + + result.data[3] = data[3] / rhs.data[3]; + result.data[4] = data[4] / rhs.data[4]; + result.data[5] = data[5] / rhs.data[5]; + + return result; + } + + /// Scalar divide operator (2-by-3) + CUTLASS_HOST_DEVICE + Matrix divide(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] / s; + result.data[1] = data[1] / s; + result.data[2] = data[2] / s; + + result.data[3] = data[3] / s; + result.data[4] = data[4] / s; + result.data[5] = data[5] / s; + + return result; + } + + /// Scalar divide operator (2-by-3) + CUTLASS_HOST_DEVICE + Matrix operator /(Element const &s) const { + return divide(s); + } + + /// Scalar divide operator (2-by-3) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Element const &s) { + + data[0] /= s; + data[1] /= s; + data[2] /= s; + + data[3] /= s; + data[4] /= s; + data[5] /= s; + + return *this; + } + + /// Elementwise divide operator (2-by-3) + CUTLASS_HOST_DEVICE + Matrix operator /(Matrix const &rhs) const { + return divide(rhs); + } + + /// Elementwise divide operator (2-by-3) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Matrix const &rhs) { + + data[0] /= rhs.data[0]; + data[1] /= rhs.data[1]; + data[2] /= rhs.data[2]; + + data[3] /= rhs.data[3]; + data[4] /= rhs.data[4]; + data[5] /= rhs.data[5]; + + return *this; + } + + /// Negates each element of the matrix + CUTLASS_HOST_DEVICE + Matrix operator-() const { + Matrix m; + + m.data[0] = -m.data[0]; + m.data[1] = -m.data[1]; + m.data[2] = -m.data[2]; + m.data[3] = -m.data[3]; + m.data[4] = -m.data[4]; + m.data[5] = -m.data[5]; + + return m; + } + + /// Matrix product of size 2-by-1-by-3 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[3] * rhs.data[0]; + + // k=1 + accum.data[0] += data[1] * rhs.data[1]; + accum.data[1] += data[4] * rhs.data[1]; + + // k=2 + accum.data[0] += data[2] * rhs.data[2]; + accum.data[1] += data[5] * rhs.data[2]; + + return accum; + } + + /// Matrix product of size 2-by-1-by-3 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 2-by-2-by-3 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[3] * rhs.data[0]; + accum.data[3] += data[3] * rhs.data[1]; + + // k=1 + accum.data[0] += data[1] * rhs.data[2]; + accum.data[1] += data[1] * rhs.data[3]; + accum.data[2] += data[4] * rhs.data[2]; + accum.data[3] += data[4] * rhs.data[3]; + + // k=2 + accum.data[0] += data[2] * rhs.data[4]; + accum.data[1] += data[2] * rhs.data[5]; + accum.data[2] += data[5] * rhs.data[4]; + accum.data[3] += data[5] * rhs.data[5]; + + return accum; + } + + /// Matrix product of size 2-by-2-by-3 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// 
Matrix product of size 2-by-3-by-3 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[3] * rhs.data[0]; + accum.data[4] += data[3] * rhs.data[1]; + accum.data[5] += data[3] * rhs.data[2]; + + // k=1 + accum.data[0] += data[1] * rhs.data[3]; + accum.data[1] += data[1] * rhs.data[4]; + accum.data[2] += data[1] * rhs.data[5]; + accum.data[3] += data[4] * rhs.data[3]; + accum.data[4] += data[4] * rhs.data[4]; + accum.data[5] += data[4] * rhs.data[5]; + + // k=2 + accum.data[0] += data[2] * rhs.data[6]; + accum.data[1] += data[2] * rhs.data[7]; + accum.data[2] += data[2] * rhs.data[8]; + accum.data[3] += data[5] * rhs.data[6]; + accum.data[4] += data[5] * rhs.data[7]; + accum.data[5] += data[5] * rhs.data[8]; + + return accum; + } + + /// Matrix product of size 2-by-3-by-3 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 2-by-3-by-3 + CUTLASS_HOST_DEVICE + Matrix & operator*=(Matrix const &rhs) { + *this = product(rhs); + return *this; + } + + /// Matrix product of size 2-by-4-by-3 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[0] * rhs.data[3]; + accum.data[4] += data[3] * rhs.data[0]; + accum.data[5] += data[3] * rhs.data[1]; + accum.data[6] += data[3] * rhs.data[2]; + accum.data[7] += data[3] * rhs.data[3]; + + // k=1 + accum.data[0] += data[1] * rhs.data[4]; + accum.data[1] += data[1] * rhs.data[5]; + accum.data[2] += data[1] * rhs.data[6]; + accum.data[3] += data[1] * rhs.data[7]; + accum.data[4] += data[4] * rhs.data[4]; + accum.data[5] += data[4] * rhs.data[5]; + accum.data[6] += data[4] * rhs.data[6]; + accum.data[7] += data[4] * rhs.data[7]; + + // k=2 + accum.data[0] += data[2] * rhs.data[8]; + accum.data[1] += data[2] * rhs.data[9]; + accum.data[2] += data[2] * rhs.data[10]; + accum.data[3] += data[2] * rhs.data[11]; + accum.data[4] += data[5] * rhs.data[8]; + accum.data[5] += data[5] * rhs.data[9]; + accum.data[6] += data[5] * rhs.data[10]; + accum.data[7] += data[5] * rhs.data[11]; + + return accum; + } + + /// Matrix product of size 2-by-4-by-3 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Returns the sum of elements + CUTLASS_HOST_DEVICE + Element sum(Element accum = Element()) const { + + accum += data[0]; + accum += data[1]; + accum += data[2]; + accum += data[3]; + accum += data[4]; + accum += data[5]; + + return accum; + } + + /// Returns the sum of squared elements + CUTLASS_HOST_DEVICE + Element norm(Element accum = Element()) const { + + accum += data[0] * data[0]; + accum += data[1] * data[1]; + accum += data[2] * data[2]; + accum += data[3] * data[3]; + accum += data[4] * data[4]; + accum += data[5] * data[5]; + + return accum; + } + + /// Returns square root of the norm + CUTLASS_HOST_DEVICE + Element magnitude() const { + return fast_sqrt(norm()); + } + + /// Returns the sum of diagonal elements + CUTLASS_HOST_DEVICE + Element trace(Element accum = Element()) const { + + accum += data[0]; + accum += data[4]; + + return accum; + } + +}; + +/// Template alias for 2-by-3 matrix +template +using Matrix2x3 = Matrix; + + +/// 
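// A short usage sketch of the 2-by-3 interface above: construction from scalars, the
// 2-by-1-by-3 matrix-vector product, scalar scaling, and the sum() reduction. Assumptions:
// float elements, "cutlass/matrix.h" included, and that the Matrix2x1 alias and the at()
// accessor of the 2-by-1 specialization follow the same generated pattern used throughout
// this file; example_2x3() is an illustrative name.

#include "cutlass/matrix.h"

inline float example_2x3() {
  cutlass::Matrix2x3<float> A = cutlass::make_Matrix2x3(
      1.0f, 2.0f, 3.0f,
      4.0f, 5.0f, 6.0f);

  cutlass::Matrix3x1<float> x = cutlass::make_Matrix3x1(1.0f, 0.0f, -1.0f);

  cutlass::Matrix2x1<float> y = A * x;   // 2-by-1-by-3 product: y = A x

  A *= 0.5f;                             // scalar operator*=

  return A.sum() + y.at(0, 0);
}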
Free funciton to infer element type from template arguments +template +CUTLASS_HOST_DEVICE Matrix2x3 make_Matrix2x3( + Element _0_0, Element _0_1, Element _0_2, + Element _1_0, Element _1_1, Element _1_2 +) { + return Matrix2x3( + _0_0, _0_1, _0_2, + _1_0, _1_1, _1_2 + ); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 2-by-4 matrix template class definition +template +struct Matrix { + + // + // Type definitions + // + + /// Element data type + using Element = Element_; + + /// Number of rows in matrix + static int const kRows = 2; + + /// Number of columns in matrix + static int const kColumns = 4; + + /// Layout of matrix in underlying array + using Layout = layout::RowMajor; + + /// Number of elements in matrix + static int const kCount = 8; + + // + // Data members + // + + /// Elements of the matrix in row-major layout + Array data; + + // + // Methods + // + + /// Constructs a zero matrix + CUTLASS_HOST_DEVICE + Matrix() { + data.clear(); + } + + /// Copy constructor for a 2-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix(Matrix const &rhs) { + data = rhs.data; + } + + /// Constucts a 2-by-4 matrix from scalar elements + CUTLASS_HOST_DEVICE + Matrix( + Element _0_0, Element _0_1, Element _0_2, Element _0_3, + Element _1_0, Element _1_1, Element _1_2, Element _1_3 + ) { + + data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3; + data[4] = _1_0; data[5] = _1_1; data[6] = _1_2; data[7] = _1_3; + } + + /// Constucts a 2-by-4 matrix from row vectors + CUTLASS_HOST_DEVICE + Matrix( + Matrix const &row_0, + Matrix const &row_1 + ) { + data[0] = row_0.data[0]; + data[1] = row_0.data[1]; + data[2] = row_0.data[2]; + data[3] = row_0.data[3]; + data[4] = row_1.data[0]; + data[5] = row_1.data[1]; + data[6] = row_1.data[2]; + data[7] = row_1.data[3]; + } + + /// Static method to construct a 2-by-4 matrix from column vectors + CUTLASS_HOST_DEVICE + static Matrix from_columns( + Matrix const &column_0, + Matrix const &column_1, + Matrix const &column_2, + Matrix const &column_3 + ) { + Matrix result; + + result.data[0] = column_0.data[0]; + result.data[1] = column_1.data[0]; + result.data[2] = column_2.data[0]; + result.data[3] = column_3.data[0]; + result.data[4] = column_0.data[1]; + result.data[5] = column_1.data[1]; + result.data[6] = column_2.data[1]; + result.data[7] = column_3.data[1]; + return result; + } + + /// Constructs a matrix from a uniform element + CUTLASS_HOST_DEVICE + static Matrix uniform(Element s) { + Matrix m; + + m.data[0] = s; + m.data[1] = s; + m.data[2] = s; + m.data[3] = s; + m.data[4] = s; + m.data[5] = s; + m.data[6] = s; + m.data[7] = s; + + return m; + } + + /// Constructs a matrix from a uniform element 1 + CUTLASS_HOST_DEVICE + static Matrix ones() { + return uniform(Element(1)); + } + + /// Constructs a matrix from a uniform element 0 + CUTLASS_HOST_DEVICE + static Matrix zero() { + return Matrix(); + } + + /// Constructs a matrix from elements along its diagonal + CUTLASS_HOST_DEVICE + static Matrix from_diagonal(Matrix const &diag) { + Matrix m; + + m.data[0] = diag.data[0]; + m.data[3] = diag.data[1]; + + return m; + } + + /// Constructs a matrix from elements along its diagonal + CUTLASS_HOST_DEVICE + static Matrix from_diagonal(Matrix const &diag) { + Matrix m; + + m.data[0] = diag.data[0]; + m.data[3] = diag.data[1]; + + return m; + } + + /// Gets an array of diagonal elements + CUTLASS_HOST_DEVICE + Matrix diagonal() const { + Matrix diag; + + diag.data[0] = data[0]; + diag.data[1] = 
data[3]; + + return diag; + } + + /// Returns a transposed matrix + CUTLASS_HOST_DEVICE + Matrix transpose() const { + Matrix mt; + + mt.data[0] = data[0]; + mt.data[2] = data[1]; + mt.data[4] = data[2]; + mt.data[6] = data[3]; + mt.data[1] = data[4]; + mt.data[3] = data[5]; + mt.data[5] = data[6]; + mt.data[7] = data[7]; + + return mt; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(int i, int j) const { + return data[i * 2 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(int i, int j) { + return data[i * 2 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element &at(int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element at(int offset) const { + return data[offset]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element operator[](Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & operator[](Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element & operator[](int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element operator[](int offset) const { + return data[offset]; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x4(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + m.data[3] = data[i * 4 + j + 3]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x4(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + data[i * 4 + j + 3] = m.data[3]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix row(int i) const { + return slice_1x4(i, 0); + } + + Matrix &set_row(Matrix const &v, int i = 0) { + return set_slice_1x4(v, i, 0); + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = 
data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 4]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 4] = m.data[1]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix column(int j) const { + return slice_2x1(0, j); + } + + Matrix &set_column(Matrix const &v, int j =0) { + return set_slice_2x1(v, 0, j); + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 4]; + m.data[3] = data[i * 4 + j + 5]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 4] = m.data[2]; + data[i * 4 + j + 5] = m.data[3]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + m.data[3] = data[i * 4 + j + 4]; + m.data[4] = data[i * 4 + j + 5]; + m.data[5] = data[i * 4 + j + 6]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + data[i * 4 + j + 4] = m.data[3]; + data[i * 4 + j + 5] = m.data[4]; + data[i * 4 + j + 6] = m.data[5]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x4(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + m.data[3] = data[i * 4 + j + 3]; + m.data[4] = data[i * 4 + j + 4]; + m.data[5] = data[i * 4 + j + 5]; + m.data[6] = data[i * 4 + j + 6]; + m.data[7] = data[i * 4 + j + 7]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x4(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + data[i * 4 + j + 3] = m.data[3]; + data[i * 4 + j + 4] = m.data[4]; + data[i * 4 + j + 5] = m.data[5]; + data[i * 4 + j + 6] = m.data[6]; + data[i * 4 + j + 7] = m.data[7]; + + return *this; + } + + /// Forms a 2-by-4 matrix by horizontally concatenating a 2-by-1 matrix with a 2-by-3 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2) + , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1), rhs.at(1, 2)); + } + + /// Forms a 2-by-4 matrix by horizontally concatenating a 2-by-2 matrix with a 2-by-2 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1) + , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0), rhs.at(1, 1)); + } + + /// Forms a 2-by-4 matrix by horizontally concatenating a 2-by-3 matrix with a 2-by-1 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), lhs.at(0, 1), 
lhs.at(0, 2), rhs.at(0, 0) + , lhs.at(1, 0), lhs.at(1, 1), lhs.at(1, 2), rhs.at(1, 0)); + } + + /// Forms a 2-by-4 matrix by vertically concatenating a 1-by-4 matrix with a 1-by-4 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Matrix const & lower) { + return Matrix( + upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) + , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3)); + } + + /// Concatenates this matrix with a a 1-by-4 matrix to form a 3-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Concatenates this matrix with a a 2-by-4 matrix to form a 4-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Forms a 2-by-4 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Element A, Matrix const & B, + Element C, Matrix const & D) { + return Matrix( + A, B.at(0, 0), B.at(0, 1), B.at(0, 2) + , C, D.at(0, 0), D.at(0, 1), D.at(0, 2) + ); + } + + /// Forms a 2-by-4 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Matrix const & C, Matrix const & D) { + return Matrix( + A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) + , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) + ); + } + + /// Forms a 2-by-4 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Element B, + Matrix const & C, Element D) { + return Matrix( + A.at(0, 0), A.at(0, 1), A.at(0, 2), B + , C.at(0, 0), C.at(0, 1), C.at(0, 2), D + ); + } + + /// Elementwise add operator (2-by-4) + CUTLASS_HOST_DEVICE + Matrix add(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] + rhs.data[0]; + result.data[1] = data[1] + rhs.data[1]; + result.data[2] = data[2] + rhs.data[2]; + result.data[3] = data[3] + rhs.data[3]; + + result.data[4] = data[4] + rhs.data[4]; + result.data[5] = data[5] + rhs.data[5]; + result.data[6] = data[6] + rhs.data[6]; + result.data[7] = data[7] + rhs.data[7]; + + return result; + } + + /// Elementwise add operator (2-by-4) + CUTLASS_HOST_DEVICE + Matrix operator +(Matrix const &rhs) const { + return add(rhs); + } + + /// Elementwise add operator (2-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator +=(Matrix const &rhs) { + + data[0] += rhs.data[0]; + data[1] += rhs.data[1]; + data[2] += rhs.data[2]; + data[3] += rhs.data[3]; + + data[4] += rhs.data[4]; + data[5] += rhs.data[5]; + data[6] += rhs.data[6]; + data[7] += rhs.data[7]; + + return *this; + } + + /// Elementwise subtract operator (2-by-4) + CUTLASS_HOST_DEVICE + Matrix subtract(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] - rhs.data[0]; + result.data[1] = data[1] - rhs.data[1]; + result.data[2] = data[2] - rhs.data[2]; + result.data[3] = data[3] - rhs.data[3]; + + result.data[4] = data[4] - rhs.data[4]; + result.data[5] = data[5] - rhs.data[5]; + result.data[6] = data[6] - rhs.data[6]; + result.data[7] = data[7] - rhs.data[7]; + + return result; + } + + /// Elementwise subtract operator (2-by-4) + CUTLASS_HOST_DEVICE + Matrix operator -(Matrix const &rhs) const { + return subtract(rhs); + } + + /// Elementwise subtract operator (2-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator -=(Matrix const &rhs) { + + data[0] -= rhs.data[0]; + data[1] -= rhs.data[1]; + data[2] -= rhs.data[2]; + data[3] -= rhs.data[3]; + + data[4] -= rhs.data[4]; + data[5] -= rhs.data[5]; + data[6] 
-= rhs.data[6]; + data[7] -= rhs.data[7]; + + return *this; + } + + /// Elementwise multiply operator (2-by-4) + CUTLASS_HOST_DEVICE + Matrix multiply(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] * rhs.data[0]; + result.data[1] = data[1] * rhs.data[1]; + result.data[2] = data[2] * rhs.data[2]; + result.data[3] = data[3] * rhs.data[3]; + + result.data[4] = data[4] * rhs.data[4]; + result.data[5] = data[5] * rhs.data[5]; + result.data[6] = data[6] * rhs.data[6]; + result.data[7] = data[7] * rhs.data[7]; + + return result; + } + + /// Scalar multiply operator (2-by-4) + CUTLASS_HOST_DEVICE + Matrix multiply(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] * s; + result.data[1] = data[1] * s; + result.data[2] = data[2] * s; + result.data[3] = data[3] * s; + + result.data[4] = data[4] * s; + result.data[5] = data[5] * s; + result.data[6] = data[6] * s; + result.data[7] = data[7] * s; + + return result; + } + + /// Scalar multiply operator (2-by-4) + CUTLASS_HOST_DEVICE + Matrix operator *(Element const &s) const { + return multiply(s); + } + + /// Scalar multiply operator (2-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator *=(Element const &s) { + + data[0] *= s; + data[1] *= s; + data[2] *= s; + data[3] *= s; + + data[4] *= s; + data[5] *= s; + data[6] *= s; + data[7] *= s; + + return *this; + } + + /// Elementwise divide operator (2-by-4) + CUTLASS_HOST_DEVICE + Matrix divide(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] / rhs.data[0]; + result.data[1] = data[1] / rhs.data[1]; + result.data[2] = data[2] / rhs.data[2]; + result.data[3] = data[3] / rhs.data[3]; + + result.data[4] = data[4] / rhs.data[4]; + result.data[5] = data[5] / rhs.data[5]; + result.data[6] = data[6] / rhs.data[6]; + result.data[7] = data[7] / rhs.data[7]; + + return result; + } + + /// Scalar divide operator (2-by-4) + CUTLASS_HOST_DEVICE + Matrix divide(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] / s; + result.data[1] = data[1] / s; + result.data[2] = data[2] / s; + result.data[3] = data[3] / s; + + result.data[4] = data[4] / s; + result.data[5] = data[5] / s; + result.data[6] = data[6] / s; + result.data[7] = data[7] / s; + + return result; + } + + /// Scalar divide operator (2-by-4) + CUTLASS_HOST_DEVICE + Matrix operator /(Element const &s) const { + return divide(s); + } + + /// Scalar divide operator (2-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Element const &s) { + + data[0] /= s; + data[1] /= s; + data[2] /= s; + data[3] /= s; + + data[4] /= s; + data[5] /= s; + data[6] /= s; + data[7] /= s; + + return *this; + } + + /// Elementwise divide operator (2-by-4) + CUTLASS_HOST_DEVICE + Matrix operator /(Matrix const &rhs) const { + return divide(rhs); + } + + /// Elementwise divide operator (2-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Matrix const &rhs) { + + data[0] /= rhs.data[0]; + data[1] /= rhs.data[1]; + data[2] /= rhs.data[2]; + data[3] /= rhs.data[3]; + + data[4] /= rhs.data[4]; + data[5] /= rhs.data[5]; + data[6] /= rhs.data[6]; + data[7] /= rhs.data[7]; + + return *this; + } + + /// Negates each element of the matrix + CUTLASS_HOST_DEVICE + Matrix operator-() const { + Matrix m; + + m.data[0] = -m.data[0]; + m.data[1] = -m.data[1]; + m.data[2] = -m.data[2]; + m.data[3] = -m.data[3]; + m.data[4] = -m.data[4]; + m.data[5] = -m.data[5]; + m.data[6] = -m.data[6]; + m.data[7] = -m.data[7]; + + return m; + } + + /// Matrix product of size 2-by-1-by-4 + CUTLASS_HOST_DEVICE + Matrix 
product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[4] * rhs.data[0]; + + // k=1 + accum.data[0] += data[1] * rhs.data[1]; + accum.data[1] += data[5] * rhs.data[1]; + + // k=2 + accum.data[0] += data[2] * rhs.data[2]; + accum.data[1] += data[6] * rhs.data[2]; + + // k=3 + accum.data[0] += data[3] * rhs.data[3]; + accum.data[1] += data[7] * rhs.data[3]; + + return accum; + } + + /// Matrix product of size 2-by-1-by-4 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 2-by-2-by-4 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[4] * rhs.data[0]; + accum.data[3] += data[4] * rhs.data[1]; + + // k=1 + accum.data[0] += data[1] * rhs.data[2]; + accum.data[1] += data[1] * rhs.data[3]; + accum.data[2] += data[5] * rhs.data[2]; + accum.data[3] += data[5] * rhs.data[3]; + + // k=2 + accum.data[0] += data[2] * rhs.data[4]; + accum.data[1] += data[2] * rhs.data[5]; + accum.data[2] += data[6] * rhs.data[4]; + accum.data[3] += data[6] * rhs.data[5]; + + // k=3 + accum.data[0] += data[3] * rhs.data[6]; + accum.data[1] += data[3] * rhs.data[7]; + accum.data[2] += data[7] * rhs.data[6]; + accum.data[3] += data[7] * rhs.data[7]; + + return accum; + } + + /// Matrix product of size 2-by-2-by-4 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 2-by-3-by-4 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[4] * rhs.data[0]; + accum.data[4] += data[4] * rhs.data[1]; + accum.data[5] += data[4] * rhs.data[2]; + + // k=1 + accum.data[0] += data[1] * rhs.data[3]; + accum.data[1] += data[1] * rhs.data[4]; + accum.data[2] += data[1] * rhs.data[5]; + accum.data[3] += data[5] * rhs.data[3]; + accum.data[4] += data[5] * rhs.data[4]; + accum.data[5] += data[5] * rhs.data[5]; + + // k=2 + accum.data[0] += data[2] * rhs.data[6]; + accum.data[1] += data[2] * rhs.data[7]; + accum.data[2] += data[2] * rhs.data[8]; + accum.data[3] += data[6] * rhs.data[6]; + accum.data[4] += data[6] * rhs.data[7]; + accum.data[5] += data[6] * rhs.data[8]; + + // k=3 + accum.data[0] += data[3] * rhs.data[9]; + accum.data[1] += data[3] * rhs.data[10]; + accum.data[2] += data[3] * rhs.data[11]; + accum.data[3] += data[7] * rhs.data[9]; + accum.data[4] += data[7] * rhs.data[10]; + accum.data[5] += data[7] * rhs.data[11]; + + return accum; + } + + /// Matrix product of size 2-by-3-by-4 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 2-by-4-by-4 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[0] * rhs.data[3]; + accum.data[4] += data[4] * rhs.data[0]; + accum.data[5] += data[4] * rhs.data[1]; + accum.data[6] += data[4] * rhs.data[2]; + accum.data[7] += data[4] * rhs.data[3]; + + // k=1 + accum.data[0] += data[1] * rhs.data[4]; + accum.data[1] += data[1] * 
rhs.data[5]; + accum.data[2] += data[1] * rhs.data[6]; + accum.data[3] += data[1] * rhs.data[7]; + accum.data[4] += data[5] * rhs.data[4]; + accum.data[5] += data[5] * rhs.data[5]; + accum.data[6] += data[5] * rhs.data[6]; + accum.data[7] += data[5] * rhs.data[7]; + + // k=2 + accum.data[0] += data[2] * rhs.data[8]; + accum.data[1] += data[2] * rhs.data[9]; + accum.data[2] += data[2] * rhs.data[10]; + accum.data[3] += data[2] * rhs.data[11]; + accum.data[4] += data[6] * rhs.data[8]; + accum.data[5] += data[6] * rhs.data[9]; + accum.data[6] += data[6] * rhs.data[10]; + accum.data[7] += data[6] * rhs.data[11]; + + // k=3 + accum.data[0] += data[3] * rhs.data[12]; + accum.data[1] += data[3] * rhs.data[13]; + accum.data[2] += data[3] * rhs.data[14]; + accum.data[3] += data[3] * rhs.data[15]; + accum.data[4] += data[7] * rhs.data[12]; + accum.data[5] += data[7] * rhs.data[13]; + accum.data[6] += data[7] * rhs.data[14]; + accum.data[7] += data[7] * rhs.data[15]; + + return accum; + } + + /// Matrix product of size 2-by-4-by-4 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 2-by-4-by-4 + CUTLASS_HOST_DEVICE + Matrix & operator*=(Matrix const &rhs) { + *this = product(rhs); + return *this; + } + + /// Returns the sum of elements + CUTLASS_HOST_DEVICE + Element sum(Element accum = Element()) const { + + accum += data[0]; + accum += data[1]; + accum += data[2]; + accum += data[3]; + accum += data[4]; + accum += data[5]; + accum += data[6]; + accum += data[7]; + + return accum; + } + + /// Returns the sum of squared elements + CUTLASS_HOST_DEVICE + Element norm(Element accum = Element()) const { + + accum += data[0] * data[0]; + accum += data[1] * data[1]; + accum += data[2] * data[2]; + accum += data[3] * data[3]; + accum += data[4] * data[4]; + accum += data[5] * data[5]; + accum += data[6] * data[6]; + accum += data[7] * data[7]; + + return accum; + } + + /// Returns square root of the norm + CUTLASS_HOST_DEVICE + Element magnitude() const { + return fast_sqrt(norm()); + } + + /// Returns the sum of diagonal elements + CUTLASS_HOST_DEVICE + Element trace(Element accum = Element()) const { + + accum += data[0]; + accum += data[5]; + + return accum; + } + +}; + +/// Template alias for 2-by-4 matrix +template +using Matrix2x4 = Matrix; + + +/// Free funciton to infer element type from template arguments +template +CUTLASS_HOST_DEVICE Matrix2x4 make_Matrix2x4( + Element _0_0, Element _0_1, Element _0_2, Element _0_3, + Element _1_0, Element _1_1, Element _1_2, Element _1_3 +) { + return Matrix2x4( + _0_0, _0_1, _0_2, _0_3, + _1_0, _1_1, _1_2, _1_3 + ); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 3-by-1 matrix template class definition +template +struct Matrix { + + // + // Type definitions + // + + /// Element data type + using Element = Element_; + + /// Number of rows in matrix + static int const kRows = 3; + + /// Number of columns in matrix + static int const kColumns = 1; + + /// Layout of matrix in underlying array + using Layout = layout::RowMajor; + + /// Number of elements in matrix + static int const kCount = 3; + + // + // Data members + // + + /// Elements of the matrix in row-major layout + Array data; + + // + // Methods + // + + /// Constructs a zero matrix + CUTLASS_HOST_DEVICE + Matrix() { + data.clear(); + } + + /// Copy constructor for a 3-by-1 matrix + CUTLASS_HOST_DEVICE + Matrix(Matrix const &rhs) { + data = rhs.data; + } + + /// 
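// A sketch of the 2-by-4 slicing and concatenation helpers above: horizontal concatenation of
// a 2-by-3 block with a 2-by-1 column, then row/column extraction and overwrite. Assumptions:
// float elements, "cutlass/matrix.h" included, and that the Matrix2x1 and Matrix1x4 aliases
// plus their uniform()/zero()/ones() factories and at() accessors follow the same generated
// pattern as the sizes shown here; example_2x4() is an illustrative name.

#include "cutlass/matrix.h"

inline cutlass::Matrix2x4<float> example_2x4() {
  cutlass::Matrix2x3<float> left = cutlass::Matrix2x3<float>::ones();
  cutlass::Matrix2x1<float> right = cutlass::Matrix2x1<float>::zero();

  // Static hcat defined above: [left | right] is 2-by-4.
  cutlass::Matrix2x4<float> M = cutlass::Matrix2x4<float>::hcat(left, right);

  // Rows and columns are small matrices themselves.
  cutlass::Matrix1x4<float> r0 = M.row(0);      // slice_1x4(0, 0)
  cutlass::Matrix2x1<float> c3 = M.column(3);   // slice_2x1(0, 3)
  (void)c3;

  // Overwrite the last column with a uniform value taken from the first row.
  M.set_column(cutlass::Matrix2x1<float>::uniform(r0.at(0, 0)), 3);

  return M;
}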
Constucts a 3-by-1 matrix from scalar elements + CUTLASS_HOST_DEVICE + Matrix( + Element _0_0, + Element _1_0, + Element _2_0 + ) { + + data[0] = _0_0; + data[1] = _1_0; + data[2] = _2_0; + } + + /// Constructs a matrix from a uniform element + CUTLASS_HOST_DEVICE + static Matrix uniform(Element s) { + Matrix m; + + m.data[0] = s; + m.data[1] = s; + m.data[2] = s; + + return m; + } + + /// Constructs a matrix from a uniform element 1 + CUTLASS_HOST_DEVICE + static Matrix ones() { + return uniform(Element(1)); + } + + /// Constructs a matrix from a uniform element 0 + CUTLASS_HOST_DEVICE + static Matrix zero() { + return Matrix(); + } + + /// Returns a transposed matrix + CUTLASS_HOST_DEVICE + Matrix transpose() const { + Matrix mt; + + mt.data[0] = data[0]; + mt.data[1] = data[1]; + mt.data[2] = data[2]; + + return mt; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(int i, int j) const { + return data[i * 3 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(int i, int j) { + return data[i * 3 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element &at(int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element at(int offset) const { + return data[offset]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element operator[](Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & operator[](Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element & operator[](int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element operator[](int offset) const { + return data[offset]; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 1 + j + 0]; + m.data[1] = data[i * 1 + j + 1]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 1 + j + 0] = m.data[0]; + data[i * 1 + j + 1] = m.data[1]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 1 + j + 0]; + m.data[1] = data[i * 1 + j + 1]; + m.data[2] = data[i * 1 + j + 2]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 1 + j + 0] = m.data[0]; + data[i * 1 + j + 1] = m.data[1]; + data[i * 1 + j + 2] = m.data[2]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix column(int j) const { + return slice_3x1(0, j); + } + + Matrix &set_column(Matrix const &v, int j =0) { + return set_slice_3x1(v, 0, j); + } + + /// Concatenates this matrix with a a 3-by-1 matrix to form a 3-by-2 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Concatenates this matrix with a a 3-by-2 matrix 
to form a 3-by-3 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Concatenates this matrix with a a 3-by-3 matrix to form a 3-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Forms a 3-by-1 matrix by vertically concatenating an Element with a 2-by-1 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Element upper, Matrix const & lower) { + return Matrix( + upper + , lower.at(0, 0) + , lower.at(1, 0)); + } + + /// Forms a 3-by-1 matrix by vertically concatenating a 2-by-1 matrix with an Element + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Element lower) { + return Matrix( + upper.at(0, 0) + , upper.at(1, 0) + , lower); + } + + /// Concatenates this matrix with a an Element to form a 4-by-1 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Element rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Elementwise add operator (3-by-1) + CUTLASS_HOST_DEVICE + Matrix add(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] + rhs.data[0]; + + result.data[1] = data[1] + rhs.data[1]; + + result.data[2] = data[2] + rhs.data[2]; + + return result; + } + + /// Elementwise add operator (3-by-1) + CUTLASS_HOST_DEVICE + Matrix operator +(Matrix const &rhs) const { + return add(rhs); + } + + /// Elementwise add operator (3-by-1) + CUTLASS_HOST_DEVICE + Matrix & operator +=(Matrix const &rhs) { + + data[0] += rhs.data[0]; + + data[1] += rhs.data[1]; + + data[2] += rhs.data[2]; + + return *this; + } + + /// Elementwise subtract operator (3-by-1) + CUTLASS_HOST_DEVICE + Matrix subtract(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] - rhs.data[0]; + + result.data[1] = data[1] - rhs.data[1]; + + result.data[2] = data[2] - rhs.data[2]; + + return result; + } + + /// Elementwise subtract operator (3-by-1) + CUTLASS_HOST_DEVICE + Matrix operator -(Matrix const &rhs) const { + return subtract(rhs); + } + + /// Elementwise subtract operator (3-by-1) + CUTLASS_HOST_DEVICE + Matrix & operator -=(Matrix const &rhs) { + + data[0] -= rhs.data[0]; + + data[1] -= rhs.data[1]; + + data[2] -= rhs.data[2]; + + return *this; + } + + /// Elementwise multiply operator (3-by-1) + CUTLASS_HOST_DEVICE + Matrix multiply(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] * rhs.data[0]; + + result.data[1] = data[1] * rhs.data[1]; + + result.data[2] = data[2] * rhs.data[2]; + + return result; + } + + /// Scalar multiply operator (3-by-1) + CUTLASS_HOST_DEVICE + Matrix multiply(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] * s; + + result.data[1] = data[1] * s; + + result.data[2] = data[2] * s; + + return result; + } + + /// Scalar multiply operator (3-by-1) + CUTLASS_HOST_DEVICE + Matrix operator *(Element const &s) const { + return multiply(s); + } + + /// Scalar multiply operator (3-by-1) + CUTLASS_HOST_DEVICE + Matrix & operator *=(Element const &s) { + + data[0] *= s; + + data[1] *= s; + + data[2] *= s; + + return *this; + } + + /// Elementwise divide operator (3-by-1) + CUTLASS_HOST_DEVICE + Matrix divide(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] / rhs.data[0]; + + result.data[1] = data[1] / rhs.data[1]; + + result.data[2] = data[2] / rhs.data[2]; + + return result; + } + + /// Scalar divide operator (3-by-1) + CUTLASS_HOST_DEVICE + Matrix divide(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] / s; + + 
result.data[1] = data[1] / s; + + result.data[2] = data[2] / s; + + return result; + } + + /// Scalar divide operator (3-by-1) + CUTLASS_HOST_DEVICE + Matrix operator /(Element const &s) const { + return divide(s); + } + + /// Scalar divide operator (3-by-1) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Element const &s) { + + data[0] /= s; + + data[1] /= s; + + data[2] /= s; + + return *this; + } + + /// Elementwise divide operator (3-by-1) + CUTLASS_HOST_DEVICE + Matrix operator /(Matrix const &rhs) const { + return divide(rhs); + } + + /// Elementwise divide operator (3-by-1) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Matrix const &rhs) { + + data[0] /= rhs.data[0]; + + data[1] /= rhs.data[1]; + + data[2] /= rhs.data[2]; + + return *this; + } + + /// Negates each element of the matrix + CUTLASS_HOST_DEVICE + Matrix operator-() const { + Matrix m; + + m.data[0] = -m.data[0]; + m.data[1] = -m.data[1]; + m.data[2] = -m.data[2]; + + return m; + } + + /// Matrix product of size 3-by-1-by-1 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[1] * rhs.data[0]; + accum.data[2] += data[2] * rhs.data[0]; + + return accum; + } + + /// Matrix product of size 3-by-1-by-1 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 3-by-1-by-1 + CUTLASS_HOST_DEVICE + Matrix & operator*=(Matrix const &rhs) { + *this = product(rhs); + return *this; + } + + /// Matrix product of size 3-by-2-by-1 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[1] * rhs.data[0]; + accum.data[3] += data[1] * rhs.data[1]; + accum.data[4] += data[2] * rhs.data[0]; + accum.data[5] += data[2] * rhs.data[1]; + + return accum; + } + + /// Matrix product of size 3-by-2-by-1 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 3-by-3-by-1 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[1] * rhs.data[0]; + accum.data[4] += data[1] * rhs.data[1]; + accum.data[5] += data[1] * rhs.data[2]; + accum.data[6] += data[2] * rhs.data[0]; + accum.data[7] += data[2] * rhs.data[1]; + accum.data[8] += data[2] * rhs.data[2]; + + return accum; + } + + /// Matrix product of size 3-by-3-by-1 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 3-by-4-by-1 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[0] * rhs.data[3]; + accum.data[4] += data[1] * rhs.data[0]; + accum.data[5] += data[1] * rhs.data[1]; + accum.data[6] += data[1] * rhs.data[2]; + accum.data[7] += data[1] * rhs.data[3]; + accum.data[8] += data[2] * rhs.data[0]; + accum.data[9] += data[2] * rhs.data[1]; + accum.data[10] += data[2] * rhs.data[2]; + accum.data[11] += data[2] * rhs.data[3]; + + return accum; + } + + /// Matrix product of size 3-by-4-by-1 + CUTLASS_HOST_DEVICE 
+ Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Dot product of vectors with extent 3 + CUTLASS_HOST_DEVICE + Element dot(Matrix const &rhs, Element accum = Element()) const { + + accum += data[0] * rhs.data[0]; + accum += data[1] * rhs.data[1]; + accum += data[2] * rhs.data[2]; + return accum; + } + + /// Dot product of vectors with extent 3 + CUTLASS_HOST_DEVICE + Element dot(Matrix const &rhs, Element accum = Element()) const { + + accum += data[0] * rhs.data[0]; + accum += data[1] * rhs.data[1]; + accum += data[2] * rhs.data[2]; + return accum; + } + + /// Returns the sum of elements + CUTLASS_HOST_DEVICE + Element sum(Element accum = Element()) const { + + accum += data[0]; + accum += data[1]; + accum += data[2]; + + return accum; + } + + /// Returns the sum of squared elements + CUTLASS_HOST_DEVICE + Element norm(Element accum = Element()) const { + + accum += data[0] * data[0]; + accum += data[1] * data[1]; + accum += data[2] * data[2]; + + return accum; + } + + /// Returns square root of the norm + CUTLASS_HOST_DEVICE + Element magnitude() const { + return fast_sqrt(norm()); + } + + /// Returns the sum of diagonal elements + CUTLASS_HOST_DEVICE + Element trace(Element accum = Element()) const { + + accum += data[0]; + + return accum; + } + + /// Cross product + CUTLASS_HOST_DEVICE + Matrix cross(Matrix const &rhs) const { + return Matrix( + data[1] * rhs.data[2] - data[2] * rhs.data[1], + data[0] * rhs.data[2] - data[2] * rhs.data[1], + data[0] * rhs.data[1] - data[1] * rhs.data[0] + ); + } + +}; + +/// Template alias for 3-by-1 matrix +template +using Matrix3x1 = Matrix; + + +/// Free funciton to infer element type from template arguments +template +CUTLASS_HOST_DEVICE Matrix3x1 make_Matrix3x1( + Element _0_0, + Element _1_0, + Element _2_0 +) { + return Matrix3x1( + _0_0, + _1_0, + _2_0 + ); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 3-by-2 matrix template class definition +template +struct Matrix { + + // + // Type definitions + // + + /// Element data type + using Element = Element_; + + /// Number of rows in matrix + static int const kRows = 3; + + /// Number of columns in matrix + static int const kColumns = 2; + + /// Layout of matrix in underlying array + using Layout = layout::RowMajor; + + /// Number of elements in matrix + static int const kCount = 6; + + // + // Data members + // + + /// Elements of the matrix in row-major layout + Array data; + + // + // Methods + // + + /// Constructs a zero matrix + CUTLASS_HOST_DEVICE + Matrix() { + data.clear(); + } + + /// Copy constructor for a 3-by-2 matrix + CUTLASS_HOST_DEVICE + Matrix(Matrix const &rhs) { + data = rhs.data; + } + + /// Constucts a 3-by-2 matrix from scalar elements + CUTLASS_HOST_DEVICE + Matrix( + Element _0_0, Element _0_1, + Element _1_0, Element _1_1, + Element _2_0, Element _2_1 + ) { + + data[0] = _0_0; data[1] = _0_1; + data[2] = _1_0; data[3] = _1_1; + data[4] = _2_0; data[5] = _2_1; + } + + /// Constucts a 3-by-2 matrix from row vectors + CUTLASS_HOST_DEVICE + Matrix( + Matrix const &row_0, + Matrix const &row_1, + Matrix const &row_2 + ) { + data[0] = row_0.data[0]; + data[1] = row_0.data[1]; + data[2] = row_1.data[0]; + data[3] = row_1.data[1]; + data[4] = row_2.data[0]; + data[5] = row_2.data[1]; + } + + /// Static method to construct a 3-by-2 matrix from column vectors + CUTLASS_HOST_DEVICE + static Matrix from_columns( + Matrix const &column_0, + Matrix const &column_1 + ) { + Matrix result; + + 
result.data[0] = column_0.data[0]; + result.data[1] = column_1.data[0]; + result.data[2] = column_0.data[1]; + result.data[3] = column_1.data[1]; + result.data[4] = column_0.data[2]; + result.data[5] = column_1.data[2]; + return result; + } + + /// Constructs a matrix from a uniform element + CUTLASS_HOST_DEVICE + static Matrix uniform(Element s) { + Matrix m; + + m.data[0] = s; + m.data[1] = s; + m.data[2] = s; + m.data[3] = s; + m.data[4] = s; + m.data[5] = s; + + return m; + } + + /// Constructs a matrix from a uniform element 1 + CUTLASS_HOST_DEVICE + static Matrix ones() { + return uniform(Element(1)); + } + + /// Constructs a matrix from a uniform element 0 + CUTLASS_HOST_DEVICE + static Matrix zero() { + return Matrix(); + } + + /// Constructs a matrix from elements along its diagonal + CUTLASS_HOST_DEVICE + static Matrix from_diagonal(Matrix const &diag) { + Matrix m; + + m.data[0] = diag.data[0]; + m.data[4] = diag.data[1]; + m.data[8] = diag.data[2]; + + return m; + } + + /// Constructs a matrix from elements along its diagonal + CUTLASS_HOST_DEVICE + static Matrix from_diagonal(Matrix const &diag) { + Matrix m; + + m.data[0] = diag.data[0]; + m.data[4] = diag.data[1]; + m.data[8] = diag.data[2]; + + return m; + } + + /// Gets an array of diagonal elements + CUTLASS_HOST_DEVICE + Matrix diagonal() const { + Matrix diag; + + diag.data[0] = data[0]; + diag.data[1] = data[4]; + diag.data[2] = data[8]; + + return diag; + } + + /// Returns a transposed matrix + CUTLASS_HOST_DEVICE + Matrix transpose() const { + Matrix mt; + + mt.data[0] = data[0]; + mt.data[3] = data[1]; + mt.data[1] = data[2]; + mt.data[4] = data[3]; + mt.data[2] = data[4]; + mt.data[5] = data[5]; + + return mt; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(int i, int j) const { + return data[i * 3 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(int i, int j) { + return data[i * 3 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element &at(int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element at(int offset) const { + return data[offset]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element operator[](Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & operator[](Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element & operator[](int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element operator[](int offset) const { + return data[offset]; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 2 + j + 0]; + m.data[1] = data[i * 2 + j + 1]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 2 + j + 0] = m.data[0]; + data[i * 2 + j + 1] = m.data[1]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix row(int i) const { + 
return slice_1x2(i, 0); + } + + Matrix &set_row(Matrix const &v, int i = 0) { + return set_slice_1x2(v, i, 0); + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 2 + j + 0]; + m.data[1] = data[i * 2 + j + 2]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 2 + j + 0] = m.data[0]; + data[i * 2 + j + 2] = m.data[1]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 2 + j + 0]; + m.data[1] = data[i * 2 + j + 1]; + m.data[2] = data[i * 2 + j + 2]; + m.data[3] = data[i * 2 + j + 3]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 2 + j + 0] = m.data[0]; + data[i * 2 + j + 1] = m.data[1]; + data[i * 2 + j + 2] = m.data[2]; + data[i * 2 + j + 3] = m.data[3]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 2 + j + 0]; + m.data[1] = data[i * 2 + j + 2]; + m.data[2] = data[i * 2 + j + 4]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 2 + j + 0] = m.data[0]; + data[i * 2 + j + 2] = m.data[1]; + data[i * 2 + j + 4] = m.data[2]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix column(int j) const { + return slice_3x1(0, j); + } + + Matrix &set_column(Matrix const &v, int j =0) { + return set_slice_3x1(v, 0, j); + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 2 + j + 0]; + m.data[1] = data[i * 2 + j + 1]; + m.data[2] = data[i * 2 + j + 2]; + m.data[3] = data[i * 2 + j + 3]; + m.data[4] = data[i * 2 + j + 4]; + m.data[5] = data[i * 2 + j + 5]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 2 + j + 0] = m.data[0]; + data[i * 2 + j + 1] = m.data[1]; + data[i * 2 + j + 2] = m.data[2]; + data[i * 2 + j + 3] = m.data[3]; + data[i * 2 + j + 4] = m.data[4]; + data[i * 2 + j + 5] = m.data[5]; + + return *this; + } + + /// Forms a 3-by-2 matrix by horizontally concatenating a 3-by-1 matrix with a 3-by-1 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), rhs.at(0, 0) + , lhs.at(1, 0), rhs.at(1, 0) + , lhs.at(2, 0), rhs.at(2, 0)); + } + + /// Concatenates this matrix with a a 3-by-1 matrix to form a 3-by-3 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Concatenates this matrix with a a 3-by-2 matrix to form a 3-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Forms a 3-by-2 matrix by vertically concatenating a 1-by-2 matrix with a 2-by-2 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Matrix const & lower) { + return Matrix( + upper.at(0, 0), upper.at(0, 1) + , lower.at(0, 0), lower.at(0, 1) + , lower.at(1, 0), lower.at(1, 
1)); + } + + /// Forms a 3-by-2 matrix by vertically concatenating a 2-by-2 matrix with a 1-by-2 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Matrix const & lower) { + return Matrix( + upper.at(0, 0), upper.at(0, 1) + , upper.at(1, 0), upper.at(1, 1) + , lower.at(0, 0), lower.at(0, 1)); + } + + /// Concatenates this matrix with a a 1-by-2 matrix to form a 4-by-2 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Forms a 3-by-2 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Element A, Element B, + Matrix const & C, Matrix const & D) { + return Matrix( + A, B + , C.at(0, 0), D.at(0, 0) + , C.at(1, 0), D.at(1, 0) + ); + } + + /// Forms a 3-by-2 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Element C, Element D) { + return Matrix( + A.at(0, 0), B.at(0, 0) + , A.at(1, 0), B.at(1, 0) + , C, D + ); + } + + /// Elementwise add operator (3-by-2) + CUTLASS_HOST_DEVICE + Matrix add(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] + rhs.data[0]; + result.data[1] = data[1] + rhs.data[1]; + + result.data[2] = data[2] + rhs.data[2]; + result.data[3] = data[3] + rhs.data[3]; + + result.data[4] = data[4] + rhs.data[4]; + result.data[5] = data[5] + rhs.data[5]; + + return result; + } + + /// Elementwise add operator (3-by-2) + CUTLASS_HOST_DEVICE + Matrix operator +(Matrix const &rhs) const { + return add(rhs); + } + + /// Elementwise add operator (3-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator +=(Matrix const &rhs) { + + data[0] += rhs.data[0]; + data[1] += rhs.data[1]; + + data[2] += rhs.data[2]; + data[3] += rhs.data[3]; + + data[4] += rhs.data[4]; + data[5] += rhs.data[5]; + + return *this; + } + + /// Elementwise subtract operator (3-by-2) + CUTLASS_HOST_DEVICE + Matrix subtract(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] - rhs.data[0]; + result.data[1] = data[1] - rhs.data[1]; + + result.data[2] = data[2] - rhs.data[2]; + result.data[3] = data[3] - rhs.data[3]; + + result.data[4] = data[4] - rhs.data[4]; + result.data[5] = data[5] - rhs.data[5]; + + return result; + } + + /// Elementwise subtract operator (3-by-2) + CUTLASS_HOST_DEVICE + Matrix operator -(Matrix const &rhs) const { + return subtract(rhs); + } + + /// Elementwise subtract operator (3-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator -=(Matrix const &rhs) { + + data[0] -= rhs.data[0]; + data[1] -= rhs.data[1]; + + data[2] -= rhs.data[2]; + data[3] -= rhs.data[3]; + + data[4] -= rhs.data[4]; + data[5] -= rhs.data[5]; + + return *this; + } + + /// Elementwise multiply operator (3-by-2) + CUTLASS_HOST_DEVICE + Matrix multiply(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] * rhs.data[0]; + result.data[1] = data[1] * rhs.data[1]; + + result.data[2] = data[2] * rhs.data[2]; + result.data[3] = data[3] * rhs.data[3]; + + result.data[4] = data[4] * rhs.data[4]; + result.data[5] = data[5] * rhs.data[5]; + + return result; + } + + /// Scalar multiply operator (3-by-2) + CUTLASS_HOST_DEVICE + Matrix multiply(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] * s; + result.data[1] = data[1] * s; + + result.data[2] = data[2] * s; + result.data[3] = data[3] * s; + + result.data[4] = data[4] * s; + result.data[5] = data[5] * s; + + return result; + } + + /// Scalar multiply operator (3-by-2) + CUTLASS_HOST_DEVICE + Matrix operator 
*(Element const &s) const { + return multiply(s); + } + + /// Scalar multiply operator (3-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator *=(Element const &s) { + + data[0] *= s; + data[1] *= s; + + data[2] *= s; + data[3] *= s; + + data[4] *= s; + data[5] *= s; + + return *this; + } + + /// Elementwise divide operator (3-by-2) + CUTLASS_HOST_DEVICE + Matrix divide(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] / rhs.data[0]; + result.data[1] = data[1] / rhs.data[1]; + + result.data[2] = data[2] / rhs.data[2]; + result.data[3] = data[3] / rhs.data[3]; + + result.data[4] = data[4] / rhs.data[4]; + result.data[5] = data[5] / rhs.data[5]; + + return result; + } + + /// Scalar divide operator (3-by-2) + CUTLASS_HOST_DEVICE + Matrix divide(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] / s; + result.data[1] = data[1] / s; + + result.data[2] = data[2] / s; + result.data[3] = data[3] / s; + + result.data[4] = data[4] / s; + result.data[5] = data[5] / s; + + return result; + } + + /// Scalar divide operator (3-by-2) + CUTLASS_HOST_DEVICE + Matrix operator /(Element const &s) const { + return divide(s); + } + + /// Scalar divide operator (3-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Element const &s) { + + data[0] /= s; + data[1] /= s; + + data[2] /= s; + data[3] /= s; + + data[4] /= s; + data[5] /= s; + + return *this; + } + + /// Elementwise divide operator (3-by-2) + CUTLASS_HOST_DEVICE + Matrix operator /(Matrix const &rhs) const { + return divide(rhs); + } + + /// Elementwise divide operator (3-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Matrix const &rhs) { + + data[0] /= rhs.data[0]; + data[1] /= rhs.data[1]; + + data[2] /= rhs.data[2]; + data[3] /= rhs.data[3]; + + data[4] /= rhs.data[4]; + data[5] /= rhs.data[5]; + + return *this; + } + + /// Negates each element of the matrix + CUTLASS_HOST_DEVICE + Matrix operator-() const { + Matrix m; + + m.data[0] = -m.data[0]; + m.data[1] = -m.data[1]; + m.data[2] = -m.data[2]; + m.data[3] = -m.data[3]; + m.data[4] = -m.data[4]; + m.data[5] = -m.data[5]; + + return m; + } + + /// Matrix product of size 3-by-1-by-2 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[2] * rhs.data[0]; + accum.data[2] += data[4] * rhs.data[0]; + + // k=1 + accum.data[0] += data[1] * rhs.data[1]; + accum.data[1] += data[3] * rhs.data[1]; + accum.data[2] += data[5] * rhs.data[1]; + + return accum; + } + + /// Matrix product of size 3-by-1-by-2 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 3-by-2-by-2 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[2] * rhs.data[0]; + accum.data[3] += data[2] * rhs.data[1]; + accum.data[4] += data[4] * rhs.data[0]; + accum.data[5] += data[4] * rhs.data[1]; + + // k=1 + accum.data[0] += data[1] * rhs.data[2]; + accum.data[1] += data[1] * rhs.data[3]; + accum.data[2] += data[3] * rhs.data[2]; + accum.data[3] += data[3] * rhs.data[3]; + accum.data[4] += data[5] * rhs.data[2]; + accum.data[5] += data[5] * rhs.data[3]; + + return accum; + } + + /// Matrix product of size 3-by-2-by-2 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 
3-by-2-by-2 + CUTLASS_HOST_DEVICE + Matrix & operator*=(Matrix const &rhs) { + *this = product(rhs); + return *this; + } + + /// Matrix product of size 3-by-3-by-2 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[2] * rhs.data[0]; + accum.data[4] += data[2] * rhs.data[1]; + accum.data[5] += data[2] * rhs.data[2]; + accum.data[6] += data[4] * rhs.data[0]; + accum.data[7] += data[4] * rhs.data[1]; + accum.data[8] += data[4] * rhs.data[2]; + + // k=1 + accum.data[0] += data[1] * rhs.data[3]; + accum.data[1] += data[1] * rhs.data[4]; + accum.data[2] += data[1] * rhs.data[5]; + accum.data[3] += data[3] * rhs.data[3]; + accum.data[4] += data[3] * rhs.data[4]; + accum.data[5] += data[3] * rhs.data[5]; + accum.data[6] += data[5] * rhs.data[3]; + accum.data[7] += data[5] * rhs.data[4]; + accum.data[8] += data[5] * rhs.data[5]; + + return accum; + } + + /// Matrix product of size 3-by-3-by-2 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 3-by-4-by-2 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[0] * rhs.data[3]; + accum.data[4] += data[2] * rhs.data[0]; + accum.data[5] += data[2] * rhs.data[1]; + accum.data[6] += data[2] * rhs.data[2]; + accum.data[7] += data[2] * rhs.data[3]; + accum.data[8] += data[4] * rhs.data[0]; + accum.data[9] += data[4] * rhs.data[1]; + accum.data[10] += data[4] * rhs.data[2]; + accum.data[11] += data[4] * rhs.data[3]; + + // k=1 + accum.data[0] += data[1] * rhs.data[4]; + accum.data[1] += data[1] * rhs.data[5]; + accum.data[2] += data[1] * rhs.data[6]; + accum.data[3] += data[1] * rhs.data[7]; + accum.data[4] += data[3] * rhs.data[4]; + accum.data[5] += data[3] * rhs.data[5]; + accum.data[6] += data[3] * rhs.data[6]; + accum.data[7] += data[3] * rhs.data[7]; + accum.data[8] += data[5] * rhs.data[4]; + accum.data[9] += data[5] * rhs.data[5]; + accum.data[10] += data[5] * rhs.data[6]; + accum.data[11] += data[5] * rhs.data[7]; + + return accum; + } + + /// Matrix product of size 3-by-4-by-2 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Returns the sum of elements + CUTLASS_HOST_DEVICE + Element sum(Element accum = Element()) const { + + accum += data[0]; + accum += data[1]; + accum += data[2]; + accum += data[3]; + accum += data[4]; + accum += data[5]; + + return accum; + } + + /// Returns the sum of squared elements + CUTLASS_HOST_DEVICE + Element norm(Element accum = Element()) const { + + accum += data[0] * data[0]; + accum += data[1] * data[1]; + accum += data[2] * data[2]; + accum += data[3] * data[3]; + accum += data[4] * data[4]; + accum += data[5] * data[5]; + + return accum; + } + + /// Returns square root of the norm + CUTLASS_HOST_DEVICE + Element magnitude() const { + return fast_sqrt(norm()); + } + + /// Returns the sum of diagonal elements + CUTLASS_HOST_DEVICE + Element trace(Element accum = Element()) const { + + accum += data[0]; + accum += data[3]; + + return accum; + } + +}; + +/// Template alias for 3-by-2 matrix +template +using Matrix3x2 = Matrix; + + +/// Free funciton to infer element type from 
template arguments +template +CUTLASS_HOST_DEVICE Matrix3x2 make_Matrix3x2( + Element _0_0, Element _0_1, + Element _1_0, Element _1_1, + Element _2_0, Element _2_1 +) { + return Matrix3x2( + _0_0, _0_1, + _1_0, _1_1, + _2_0, _2_1 + ); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 3-by-3 matrix template class definition +template +struct Matrix { + + // + // Type definitions + // + + /// Element data type + using Element = Element_; + + /// Number of rows in matrix + static int const kRows = 3; + + /// Number of columns in matrix + static int const kColumns = 3; + + /// Layout of matrix in underlying array + using Layout = layout::RowMajor; + + /// Number of elements in matrix + static int const kCount = 9; + + // + // Data members + // + + /// Elements of the matrix in row-major layout + Array data; + + // + // Methods + // + + /// Constructs a zero matrix + CUTLASS_HOST_DEVICE + Matrix() { + data.clear(); + } + + /// Copy constructor for a 3-by-3 matrix + CUTLASS_HOST_DEVICE + Matrix(Matrix const &rhs) { + data = rhs.data; + } + + /// Constucts a 3-by-3 matrix from scalar elements + CUTLASS_HOST_DEVICE + Matrix( + Element _0_0, Element _0_1, Element _0_2, + Element _1_0, Element _1_1, Element _1_2, + Element _2_0, Element _2_1, Element _2_2 + ) { + + data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; + data[3] = _1_0; data[4] = _1_1; data[5] = _1_2; + data[6] = _2_0; data[7] = _2_1; data[8] = _2_2; + } + + /// Constucts a 3-by-3 matrix from row vectors + CUTLASS_HOST_DEVICE + Matrix( + Matrix const &row_0, + Matrix const &row_1, + Matrix const &row_2 + ) { + data[0] = row_0.data[0]; + data[1] = row_0.data[1]; + data[2] = row_0.data[2]; + data[3] = row_1.data[0]; + data[4] = row_1.data[1]; + data[5] = row_1.data[2]; + data[6] = row_2.data[0]; + data[7] = row_2.data[1]; + data[8] = row_2.data[2]; + } + + /// Static method to construct a 3-by-3 matrix from column vectors + CUTLASS_HOST_DEVICE + static Matrix from_columns( + Matrix const &column_0, + Matrix const &column_1, + Matrix const &column_2 + ) { + Matrix result; + + result.data[0] = column_0.data[0]; + result.data[1] = column_1.data[0]; + result.data[2] = column_2.data[0]; + result.data[3] = column_0.data[1]; + result.data[4] = column_1.data[1]; + result.data[5] = column_2.data[1]; + result.data[6] = column_0.data[2]; + result.data[7] = column_1.data[2]; + result.data[8] = column_2.data[2]; + return result; + } + + /// Constructs an identity matrix + CUTLASS_HOST_DEVICE + static Matrix identity() { + Matrix m; + + m.data[0] = Element(1); + m.data[4] = Element(1); + m.data[8] = Element(1); + + return m; + } + + /// Constructs a matrix from a uniform element + CUTLASS_HOST_DEVICE + static Matrix uniform(Element s) { + Matrix m; + + m.data[0] = s; + m.data[1] = s; + m.data[2] = s; + m.data[3] = s; + m.data[4] = s; + m.data[5] = s; + m.data[6] = s; + m.data[7] = s; + m.data[8] = s; + + return m; + } + + /// Constructs a matrix from a uniform element 1 + CUTLASS_HOST_DEVICE + static Matrix ones() { + return uniform(Element(1)); + } + + /// Constructs a matrix from a uniform element 0 + CUTLASS_HOST_DEVICE + static Matrix zero() { + return Matrix(); + } + + /// Constructs a matrix from elements along its diagonal + CUTLASS_HOST_DEVICE + static Matrix from_diagonal(Matrix const &diag) { + Matrix m; + + m.data[0] = diag.data[0]; + m.data[4] = diag.data[1]; + m.data[8] = diag.data[2]; + + return m; + } + + /// Constructs a matrix from elements along its diagonal + 
CUTLASS_HOST_DEVICE + static Matrix from_diagonal(Matrix const &diag) { + Matrix m; + + m.data[0] = diag.data[0]; + m.data[4] = diag.data[1]; + m.data[8] = diag.data[2]; + + return m; + } + + /// Gets an array of diagonal elements + CUTLASS_HOST_DEVICE + Matrix diagonal() const { + Matrix diag; + + diag.data[0] = data[0]; + diag.data[1] = data[4]; + diag.data[2] = data[8]; + + return diag; + } + + /// Returns a transposed matrix + CUTLASS_HOST_DEVICE + Matrix transpose() const { + Matrix mt; + + mt.data[0] = data[0]; + mt.data[3] = data[1]; + mt.data[6] = data[2]; + mt.data[1] = data[3]; + mt.data[4] = data[4]; + mt.data[7] = data[5]; + mt.data[2] = data[6]; + mt.data[5] = data[7]; + mt.data[8] = data[8]; + + return mt; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(int i, int j) const { + return data[i * 3 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(int i, int j) { + return data[i * 3 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element &at(int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element at(int offset) const { + return data[offset]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element operator[](Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & operator[](Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element & operator[](int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element operator[](int offset) const { + return data[offset]; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + m.data[2] = data[i * 3 + j + 2]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + data[i * 3 + j + 2] = m.data[2]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix row(int i) const { + return slice_1x3(i, 0); + } + + Matrix &set_row(Matrix const &v, int i = 0) { + return set_slice_1x3(v, i, 0); + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 3]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x1(Matrix const &m, int i = 0, 
int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 3] = m.data[1]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + m.data[2] = data[i * 3 + j + 3]; + m.data[3] = data[i * 3 + j + 4]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + data[i * 3 + j + 3] = m.data[2]; + data[i * 3 + j + 4] = m.data[3]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + m.data[2] = data[i * 3 + j + 2]; + m.data[3] = data[i * 3 + j + 3]; + m.data[4] = data[i * 3 + j + 4]; + m.data[5] = data[i * 3 + j + 5]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + data[i * 3 + j + 2] = m.data[2]; + data[i * 3 + j + 3] = m.data[3]; + data[i * 3 + j + 4] = m.data[4]; + data[i * 3 + j + 5] = m.data[5]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 3]; + m.data[2] = data[i * 3 + j + 6]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 3] = m.data[1]; + data[i * 3 + j + 6] = m.data[2]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix column(int j) const { + return slice_3x1(0, j); + } + + Matrix &set_column(Matrix const &v, int j =0) { + return set_slice_3x1(v, 0, j); + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + m.data[2] = data[i * 3 + j + 3]; + m.data[3] = data[i * 3 + j + 4]; + m.data[4] = data[i * 3 + j + 6]; + m.data[5] = data[i * 3 + j + 7]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + data[i * 3 + j + 3] = m.data[2]; + data[i * 3 + j + 4] = m.data[3]; + data[i * 3 + j + 6] = m.data[4]; + data[i * 3 + j + 7] = m.data[5]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + m.data[2] = data[i * 3 + j + 2]; + m.data[3] = data[i * 3 + j + 3]; + m.data[4] = data[i * 3 + j + 4]; + m.data[5] = data[i * 3 + j + 5]; + m.data[6] = data[i * 3 + j + 6]; + m.data[7] = data[i * 3 + j + 7]; + m.data[8] = data[i * 3 + j + 8]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + data[i * 3 + 
j + 2] = m.data[2]; + data[i * 3 + j + 3] = m.data[3]; + data[i * 3 + j + 4] = m.data[4]; + data[i * 3 + j + 5] = m.data[5]; + data[i * 3 + j + 6] = m.data[6]; + data[i * 3 + j + 7] = m.data[7]; + data[i * 3 + j + 8] = m.data[8]; + + return *this; + } + + /// Forms a 3-by-3 matrix by horizontally concatenating a 3-by-1 matrix with a 3-by-2 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1) + , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1) + , lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1)); + } + + /// Forms a 3-by-3 matrix by horizontally concatenating a 3-by-2 matrix with a 3-by-1 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0) + , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0) + , lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0)); + } + + /// Concatenates this matrix with a a 3-by-1 matrix to form a 3-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Forms a 3-by-3 matrix by vertically concatenating a 1-by-3 matrix with a 2-by-3 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Matrix const & lower) { + return Matrix( + upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) + , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2) + , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2)); + } + + /// Forms a 3-by-3 matrix by vertically concatenating a 2-by-3 matrix with a 1-by-3 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Matrix const & lower) { + return Matrix( + upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) + , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2) + , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2)); + } + + /// Concatenates this matrix with a a 1-by-3 matrix to form a 4-by-3 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Forms a 3-by-3 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Element A, Matrix const & B, + Matrix const & C, Matrix const & D) { + return Matrix( + A, B.at(0, 0), B.at(0, 1) + , C.at(0, 0), D.at(0, 0), D.at(0, 1) + , C.at(1, 0), D.at(1, 0), D.at(1, 1) + ); + } + + /// Forms a 3-by-3 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Element B, + Matrix const & C, Matrix const & D) { + return Matrix( + A.at(0, 0), A.at(0, 1), B + , C.at(0, 0), C.at(0, 1), D.at(0, 0) + , C.at(1, 0), C.at(1, 1), D.at(1, 0) + ); + } + + /// Forms a 3-by-3 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Element C, Matrix const & D) { + return Matrix( + A.at(0, 0), B.at(0, 0), B.at(0, 1) + , A.at(1, 0), B.at(1, 0), B.at(1, 1) + , C, D.at(0, 0), D.at(0, 1) + ); + } + + /// Forms a 3-by-3 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Matrix const & C, Element D) { + return Matrix( + A.at(0, 0), A.at(0, 1), B.at(0, 0) + , A.at(1, 0), A.at(1, 1), B.at(1, 0) + , C.at(0, 0), C.at(0, 1), D + ); + } + + /// Elementwise add operator (3-by-3) + CUTLASS_HOST_DEVICE + Matrix add(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] + rhs.data[0]; + result.data[1] = data[1] + rhs.data[1]; + result.data[2] = data[2] + rhs.data[2]; + + result.data[3] = data[3] + rhs.data[3]; + 
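// (the remaining entries of rows 1 and 2 of the elementwise sum continue below; operator+ forwards to add(), while operator+= applies the same update in place)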
result.data[4] = data[4] + rhs.data[4]; + result.data[5] = data[5] + rhs.data[5]; + + result.data[6] = data[6] + rhs.data[6]; + result.data[7] = data[7] + rhs.data[7]; + result.data[8] = data[8] + rhs.data[8]; + + return result; + } + + /// Elementwise add operator (3-by-3) + CUTLASS_HOST_DEVICE + Matrix operator +(Matrix const &rhs) const { + return add(rhs); + } + + /// Elementwise add operator (3-by-3) + CUTLASS_HOST_DEVICE + Matrix & operator +=(Matrix const &rhs) { + + data[0] += rhs.data[0]; + data[1] += rhs.data[1]; + data[2] += rhs.data[2]; + + data[3] += rhs.data[3]; + data[4] += rhs.data[4]; + data[5] += rhs.data[5]; + + data[6] += rhs.data[6]; + data[7] += rhs.data[7]; + data[8] += rhs.data[8]; + + return *this; + } + + /// Elementwise subtract operator (3-by-3) + CUTLASS_HOST_DEVICE + Matrix subtract(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] - rhs.data[0]; + result.data[1] = data[1] - rhs.data[1]; + result.data[2] = data[2] - rhs.data[2]; + + result.data[3] = data[3] - rhs.data[3]; + result.data[4] = data[4] - rhs.data[4]; + result.data[5] = data[5] - rhs.data[5]; + + result.data[6] = data[6] - rhs.data[6]; + result.data[7] = data[7] - rhs.data[7]; + result.data[8] = data[8] - rhs.data[8]; + + return result; + } + + /// Elementwise subtract operator (3-by-3) + CUTLASS_HOST_DEVICE + Matrix operator -(Matrix const &rhs) const { + return subtract(rhs); + } + + /// Elementwise subtract operator (3-by-3) + CUTLASS_HOST_DEVICE + Matrix & operator -=(Matrix const &rhs) { + + data[0] -= rhs.data[0]; + data[1] -= rhs.data[1]; + data[2] -= rhs.data[2]; + + data[3] -= rhs.data[3]; + data[4] -= rhs.data[4]; + data[5] -= rhs.data[5]; + + data[6] -= rhs.data[6]; + data[7] -= rhs.data[7]; + data[8] -= rhs.data[8]; + + return *this; + } + + /// Elementwise multiply operator (3-by-3) + CUTLASS_HOST_DEVICE + Matrix multiply(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] * rhs.data[0]; + result.data[1] = data[1] * rhs.data[1]; + result.data[2] = data[2] * rhs.data[2]; + + result.data[3] = data[3] * rhs.data[3]; + result.data[4] = data[4] * rhs.data[4]; + result.data[5] = data[5] * rhs.data[5]; + + result.data[6] = data[6] * rhs.data[6]; + result.data[7] = data[7] * rhs.data[7]; + result.data[8] = data[8] * rhs.data[8]; + + return result; + } + + /// Scalar multiply operator (3-by-3) + CUTLASS_HOST_DEVICE + Matrix multiply(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] * s; + result.data[1] = data[1] * s; + result.data[2] = data[2] * s; + + result.data[3] = data[3] * s; + result.data[4] = data[4] * s; + result.data[5] = data[5] * s; + + result.data[6] = data[6] * s; + result.data[7] = data[7] * s; + result.data[8] = data[8] * s; + + return result; + } + + /// Scalar multiply operator (3-by-3) + CUTLASS_HOST_DEVICE + Matrix operator *(Element const &s) const { + return multiply(s); + } + + /// Scalar multiply operator (3-by-3) + CUTLASS_HOST_DEVICE + Matrix & operator *=(Element const &s) { + + data[0] *= s; + data[1] *= s; + data[2] *= s; + + data[3] *= s; + data[4] *= s; + data[5] *= s; + + data[6] *= s; + data[7] *= s; + data[8] *= s; + + return *this; + } + + /// Elementwise divide operator (3-by-3) + CUTLASS_HOST_DEVICE + Matrix divide(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] / rhs.data[0]; + result.data[1] = data[1] / rhs.data[1]; + result.data[2] = data[2] / rhs.data[2]; + + result.data[3] = data[3] / rhs.data[3]; + result.data[4] = data[4] / rhs.data[4]; + result.data[5] 
= data[5] / rhs.data[5]; + + result.data[6] = data[6] / rhs.data[6]; + result.data[7] = data[7] / rhs.data[7]; + result.data[8] = data[8] / rhs.data[8]; + + return result; + } + + /// Scalar divide operator (3-by-3) + CUTLASS_HOST_DEVICE + Matrix divide(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] / s; + result.data[1] = data[1] / s; + result.data[2] = data[2] / s; + + result.data[3] = data[3] / s; + result.data[4] = data[4] / s; + result.data[5] = data[5] / s; + + result.data[6] = data[6] / s; + result.data[7] = data[7] / s; + result.data[8] = data[8] / s; + + return result; + } + + /// Scalar divide operator (3-by-3) + CUTLASS_HOST_DEVICE + Matrix operator /(Element const &s) const { + return divide(s); + } + + /// Scalar divide operator (3-by-3) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Element const &s) { + + data[0] /= s; + data[1] /= s; + data[2] /= s; + + data[3] /= s; + data[4] /= s; + data[5] /= s; + + data[6] /= s; + data[7] /= s; + data[8] /= s; + + return *this; + } + + /// Elementwise divide operator (3-by-3) + CUTLASS_HOST_DEVICE + Matrix operator /(Matrix const &rhs) const { + return divide(rhs); + } + + /// Elementwise divide operator (3-by-3) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Matrix const &rhs) { + + data[0] /= rhs.data[0]; + data[1] /= rhs.data[1]; + data[2] /= rhs.data[2]; + + data[3] /= rhs.data[3]; + data[4] /= rhs.data[4]; + data[5] /= rhs.data[5]; + + data[6] /= rhs.data[6]; + data[7] /= rhs.data[7]; + data[8] /= rhs.data[8]; + + return *this; + } + + /// Negates each element of the matrix + CUTLASS_HOST_DEVICE + Matrix operator-() const { + Matrix m; + + m.data[0] = -m.data[0]; + m.data[1] = -m.data[1]; + m.data[2] = -m.data[2]; + m.data[3] = -m.data[3]; + m.data[4] = -m.data[4]; + m.data[5] = -m.data[5]; + m.data[6] = -m.data[6]; + m.data[7] = -m.data[7]; + m.data[8] = -m.data[8]; + + return m; + } + + /// Matrix product of size 3-by-1-by-3 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[3] * rhs.data[0]; + accum.data[2] += data[6] * rhs.data[0]; + + // k=1 + accum.data[0] += data[1] * rhs.data[1]; + accum.data[1] += data[4] * rhs.data[1]; + accum.data[2] += data[7] * rhs.data[1]; + + // k=2 + accum.data[0] += data[2] * rhs.data[2]; + accum.data[1] += data[5] * rhs.data[2]; + accum.data[2] += data[8] * rhs.data[2]; + + return accum; + } + + /// Matrix product of size 3-by-1-by-3 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 3-by-2-by-3 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[3] * rhs.data[0]; + accum.data[3] += data[3] * rhs.data[1]; + accum.data[4] += data[6] * rhs.data[0]; + accum.data[5] += data[6] * rhs.data[1]; + + // k=1 + accum.data[0] += data[1] * rhs.data[2]; + accum.data[1] += data[1] * rhs.data[3]; + accum.data[2] += data[4] * rhs.data[2]; + accum.data[3] += data[4] * rhs.data[3]; + accum.data[4] += data[7] * rhs.data[2]; + accum.data[5] += data[7] * rhs.data[3]; + + // k=2 + accum.data[0] += data[2] * rhs.data[4]; + accum.data[1] += data[2] * rhs.data[5]; + accum.data[2] += data[5] * rhs.data[4]; + accum.data[3] += data[5] * rhs.data[5]; + accum.data[4] += data[8] * rhs.data[4]; + accum.data[5] += data[8] * rhs.data[5]; + + return 
accum; + } + + /// Matrix product of size 3-by-2-by-3 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 3-by-3-by-3 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[3] * rhs.data[0]; + accum.data[4] += data[3] * rhs.data[1]; + accum.data[5] += data[3] * rhs.data[2]; + accum.data[6] += data[6] * rhs.data[0]; + accum.data[7] += data[6] * rhs.data[1]; + accum.data[8] += data[6] * rhs.data[2]; + + // k=1 + accum.data[0] += data[1] * rhs.data[3]; + accum.data[1] += data[1] * rhs.data[4]; + accum.data[2] += data[1] * rhs.data[5]; + accum.data[3] += data[4] * rhs.data[3]; + accum.data[4] += data[4] * rhs.data[4]; + accum.data[5] += data[4] * rhs.data[5]; + accum.data[6] += data[7] * rhs.data[3]; + accum.data[7] += data[7] * rhs.data[4]; + accum.data[8] += data[7] * rhs.data[5]; + + // k=2 + accum.data[0] += data[2] * rhs.data[6]; + accum.data[1] += data[2] * rhs.data[7]; + accum.data[2] += data[2] * rhs.data[8]; + accum.data[3] += data[5] * rhs.data[6]; + accum.data[4] += data[5] * rhs.data[7]; + accum.data[5] += data[5] * rhs.data[8]; + accum.data[6] += data[8] * rhs.data[6]; + accum.data[7] += data[8] * rhs.data[7]; + accum.data[8] += data[8] * rhs.data[8]; + + return accum; + } + + /// Matrix product of size 3-by-3-by-3 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 3-by-3-by-3 + CUTLASS_HOST_DEVICE + Matrix & operator*=(Matrix const &rhs) { + *this = product(rhs); + return *this; + } + + /// Matrix product of size 3-by-4-by-3 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[0] * rhs.data[3]; + accum.data[4] += data[3] * rhs.data[0]; + accum.data[5] += data[3] * rhs.data[1]; + accum.data[6] += data[3] * rhs.data[2]; + accum.data[7] += data[3] * rhs.data[3]; + accum.data[8] += data[6] * rhs.data[0]; + accum.data[9] += data[6] * rhs.data[1]; + accum.data[10] += data[6] * rhs.data[2]; + accum.data[11] += data[6] * rhs.data[3]; + + // k=1 + accum.data[0] += data[1] * rhs.data[4]; + accum.data[1] += data[1] * rhs.data[5]; + accum.data[2] += data[1] * rhs.data[6]; + accum.data[3] += data[1] * rhs.data[7]; + accum.data[4] += data[4] * rhs.data[4]; + accum.data[5] += data[4] * rhs.data[5]; + accum.data[6] += data[4] * rhs.data[6]; + accum.data[7] += data[4] * rhs.data[7]; + accum.data[8] += data[7] * rhs.data[4]; + accum.data[9] += data[7] * rhs.data[5]; + accum.data[10] += data[7] * rhs.data[6]; + accum.data[11] += data[7] * rhs.data[7]; + + // k=2 + accum.data[0] += data[2] * rhs.data[8]; + accum.data[1] += data[2] * rhs.data[9]; + accum.data[2] += data[2] * rhs.data[10]; + accum.data[3] += data[2] * rhs.data[11]; + accum.data[4] += data[5] * rhs.data[8]; + accum.data[5] += data[5] * rhs.data[9]; + accum.data[6] += data[5] * rhs.data[10]; + accum.data[7] += data[5] * rhs.data[11]; + accum.data[8] += data[8] * rhs.data[8]; + accum.data[9] += data[8] * rhs.data[9]; + accum.data[10] += data[8] * rhs.data[10]; + accum.data[11] += data[8] * rhs.data[11]; + + return accum; + } + + /// Matrix product of size 3-by-4-by-3 + 
CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Returns the sum of elements + CUTLASS_HOST_DEVICE + Element sum(Element accum = Element()) const { + + accum += data[0]; + accum += data[1]; + accum += data[2]; + accum += data[3]; + accum += data[4]; + accum += data[5]; + accum += data[6]; + accum += data[7]; + accum += data[8]; + + return accum; + } + + /// Returns the sum of squared elements + CUTLASS_HOST_DEVICE + Element norm(Element accum = Element()) const { + + accum += data[0] * data[0]; + accum += data[1] * data[1]; + accum += data[2] * data[2]; + accum += data[3] * data[3]; + accum += data[4] * data[4]; + accum += data[5] * data[5]; + accum += data[6] * data[6]; + accum += data[7] * data[7]; + accum += data[8] * data[8]; + + return accum; + } + + /// Returns square root of the norm + CUTLASS_HOST_DEVICE + Element magnitude() const { + return fast_sqrt(norm()); + } + + /// Returns the sum of diagonal elements + CUTLASS_HOST_DEVICE + Element trace(Element accum = Element()) const { + + accum += data[0]; + accum += data[4]; + accum += data[8]; + + return accum; + } + + /// Returns 3-by-3 rotation matrix around the X axis + CUTLASS_HOST_DEVICE + static Matrix rotation_X(Element theta) { + Matrix m = identity(); + + Element c = fast_cos(theta); + Element s = fast_sin(theta); + + m.at(1, 1) = c; + m.at(1, 2) = -s; + m.at(2, 1) = s; + m.at(2, 2) = c; + + return m; + } + + /// Returns 3-by-3 rotation matrix around the Y axis + CUTLASS_HOST_DEVICE + static Matrix rotation_Y(Element theta) { + Matrix m = identity(); + + Element c = fast_cos(theta); + Element s = fast_sin(theta); + + m.at(0, 0) = c; + m.at(2, 0) = -s; + m.at(0, 2) = s; + m.at(2, 2) = c; + + return m; + } + + /// Returns 3-by-3 rotation matrix around the Z axis + CUTLASS_HOST_DEVICE + static Matrix rotation_Z(Element theta) { + Matrix m = Matrix::identity(); + + Element c = fast_cos(theta); + Element s = fast_sin(theta); + + m.at(0, 0) = c; + m.at(0, 1) = -s; + m.at(1, 0) = s; + m.at(1, 1) = c; + + return m; + } + + /// Returns a 3-by-3 rotation matrix around a unit-length axis + CUTLASS_HOST_DEVICE + static Matrix rotation(Element theta, Matrix const &u) { + Element x = u.data[0]; + Element y = u.data[1]; + Element z = u.data[2]; + + Element c = fast_cos(theta); + Element s = fast_sin(theta); + + Element one_minus_cos = Element(1) - fast_cos(theta); + + Matrix m; + + m.set_slice3x3({ + c + x * x * one_minus_cos, x * y * one_minus_cos - z * s, x * z * one_minus_cos + y * s, + y * x * one_minus_cos * z * s, c + y * y * one_minus_cos, y * z * one_minus_cos - x * s, + z * x * one_minus_cos - y * s, z * y * one_minus_cos + x * s, c + z * z * one_minus_cos + }); + + return m; + } + + /// Returns a 3-by-3 reflection about the plane specified by the + /// unit-length normal vector n_unit + CUTLASS_HOST_DEVICE + static Matrix reflection(Matrix const &n_unit) { + + Element a = n_unit.data[0]; + Element b = n_unit.data[1]; + Element c = n_unit.data[2]; + + Matrix m = Matrix::identity(); + + m.set_slice3x3({ + Element(1) - Element(2) * a * a, Element(-2) * a * b, Element(-2) * a * c, + Element(-2) * a * b, Element(1) - Element(2) * b * b, Element(-2) * b * c, + Element(-2) * a * c, Element(-2) * b * c, Element(1) - Element(2) * c * c + }); + + return m; + } + + /// Computes the determinant of a 3-by-3 matrix + CUTLASS_HOST_DEVICE + Element determinant(Element accum = Element()) const { + + accum += at(0, 0) * Matrix({ at(1, 1), at(1, 2), at(2, 1), at(2, 2) }).determinant(); + accum -= 
at(0, 1) * Matrix({ at(1, 0), at(1, 2), at(2, 0), at(2, 2) }).determinant(); + accum += at(0, 2) * Matrix({ at(1, 0), at(1, 1), at(2, 0), at(2, 1) }).determinant(); + + return accum; + } + + /// Computes the inverse of a 3-by-3 matrix given + /// the matrix's determinant + CUTLASS_HOST_DEVICE + Matrix inverse(Element det) const { + return Matrix( + at(1, 1) * at(2, 2) - at(1, 2) * at(2, 1), + at(0, 2) * at(2, 1) - at(0, 1) * at(2, 2), + at(0, 1) * at(1, 2) - at(0, 2) * at(1, 1), + + at(1, 2) * at(2, 0) - at(1, 0) * at(2, 2), + at(0, 0) * at(2, 2) - at(0, 2) * at(2, 0), + at(0, 2) * at(1, 0) - at(0, 0) * at(1, 2), + + at(1, 0) * at(2, 1) - at(1, 1) * at(2, 0), + at(0, 1) * at(2, 0) - at(0, 0) * at(2, 1), + at(0, 0) * at(1, 1) - at(0, 1) * at(1, 0) + ) * (Element(1) / det); + } + /// Computes the inverse of a 3-by-3 matrix + CUTLASS_HOST_DEVICE + Matrix inverse() const { + return inverse(determinant()); + } + +}; + +/// Template alias for 3-by-3 matrix +template +using Matrix3x3 = Matrix; + + +/// Free funciton to infer element type from template arguments +template +CUTLASS_HOST_DEVICE Matrix3x3 make_Matrix3x3( + Element _0_0, Element _0_1, Element _0_2, + Element _1_0, Element _1_1, Element _1_2, + Element _2_0, Element _2_1, Element _2_2 +) { + return Matrix3x3( + _0_0, _0_1, _0_2, + _1_0, _1_1, _1_2, + _2_0, _2_1, _2_2 + ); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 3-by-4 matrix template class definition +template +struct Matrix { + + // + // Type definitions + // + + /// Element data type + using Element = Element_; + + /// Number of rows in matrix + static int const kRows = 3; + + /// Number of columns in matrix + static int const kColumns = 4; + + /// Layout of matrix in underlying array + using Layout = layout::RowMajor; + + /// Number of elements in matrix + static int const kCount = 12; + + // + // Data members + // + + /// Elements of the matrix in row-major layout + Array data; + + // + // Methods + // + + /// Constructs a zero matrix + CUTLASS_HOST_DEVICE + Matrix() { + data.clear(); + } + + /// Copy constructor for a 3-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix(Matrix const &rhs) { + data = rhs.data; + } + + /// Constucts a 3-by-4 matrix from scalar elements + CUTLASS_HOST_DEVICE + Matrix( + Element _0_0, Element _0_1, Element _0_2, Element _0_3, + Element _1_0, Element _1_1, Element _1_2, Element _1_3, + Element _2_0, Element _2_1, Element _2_2, Element _2_3 + ) { + + data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3; + data[4] = _1_0; data[5] = _1_1; data[6] = _1_2; data[7] = _1_3; + data[8] = _2_0; data[9] = _2_1; data[10] = _2_2; data[11] = _2_3; + } + + /// Constucts a 3-by-4 matrix from row vectors + CUTLASS_HOST_DEVICE + Matrix( + Matrix const &row_0, + Matrix const &row_1, + Matrix const &row_2 + ) { + data[0] = row_0.data[0]; + data[1] = row_0.data[1]; + data[2] = row_0.data[2]; + data[3] = row_0.data[3]; + data[4] = row_1.data[0]; + data[5] = row_1.data[1]; + data[6] = row_1.data[2]; + data[7] = row_1.data[3]; + data[8] = row_2.data[0]; + data[9] = row_2.data[1]; + data[10] = row_2.data[2]; + data[11] = row_2.data[3]; + } + + /// Static method to construct a 3-by-4 matrix from column vectors + CUTLASS_HOST_DEVICE + static Matrix from_columns( + Matrix const &column_0, + Matrix const &column_1, + Matrix const &column_2, + Matrix const &column_3 + ) { + Matrix result; + + result.data[0] = column_0.data[0]; + result.data[1] = column_1.data[0]; + result.data[2] = column_2.data[0]; + 
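// (row-major gather continues: result.data[3] completes row 0, then rows 1 and 2 take the i-th entry of each column vector)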
result.data[3] = column_3.data[0]; + result.data[4] = column_0.data[1]; + result.data[5] = column_1.data[1]; + result.data[6] = column_2.data[1]; + result.data[7] = column_3.data[1]; + result.data[8] = column_0.data[2]; + result.data[9] = column_1.data[2]; + result.data[10] = column_2.data[2]; + result.data[11] = column_3.data[2]; + return result; + } + + /// Constructs a matrix from a uniform element + CUTLASS_HOST_DEVICE + static Matrix uniform(Element s) { + Matrix m; + + m.data[0] = s; + m.data[1] = s; + m.data[2] = s; + m.data[3] = s; + m.data[4] = s; + m.data[5] = s; + m.data[6] = s; + m.data[7] = s; + m.data[8] = s; + m.data[9] = s; + m.data[10] = s; + m.data[11] = s; + + return m; + } + + /// Constructs a matrix from a uniform element 1 + CUTLASS_HOST_DEVICE + static Matrix ones() { + return uniform(Element(1)); + } + + /// Constructs a matrix from a uniform element 0 + CUTLASS_HOST_DEVICE + static Matrix zero() { + return Matrix(); + } + + /// Constructs a matrix from elements along its diagonal + CUTLASS_HOST_DEVICE + static Matrix from_diagonal(Matrix const &diag) { + Matrix m; + + m.data[0] = diag.data[0]; + m.data[4] = diag.data[1]; + m.data[8] = diag.data[2]; + + return m; + } + + /// Constructs a matrix from elements along its diagonal + CUTLASS_HOST_DEVICE + static Matrix from_diagonal(Matrix const &diag) { + Matrix m; + + m.data[0] = diag.data[0]; + m.data[4] = diag.data[1]; + m.data[8] = diag.data[2]; + + return m; + } + + /// Gets an array of diagonal elements + CUTLASS_HOST_DEVICE + Matrix diagonal() const { + Matrix diag; + + diag.data[0] = data[0]; + diag.data[1] = data[4]; + diag.data[2] = data[8]; + + return diag; + } + + /// Returns a transposed matrix + CUTLASS_HOST_DEVICE + Matrix transpose() const { + Matrix mt; + + mt.data[0] = data[0]; + mt.data[3] = data[1]; + mt.data[6] = data[2]; + mt.data[9] = data[3]; + mt.data[1] = data[4]; + mt.data[4] = data[5]; + mt.data[7] = data[6]; + mt.data[10] = data[7]; + mt.data[2] = data[8]; + mt.data[5] = data[9]; + mt.data[8] = data[10]; + mt.data[11] = data[11]; + + return mt; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(int i, int j) const { + return data[i * 3 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(int i, int j) { + return data[i * 3 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element &at(int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element at(int offset) const { + return data[offset]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element operator[](Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & operator[](Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element & operator[](int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element operator[](int offset) const { + return data[offset]; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] 
= data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x4(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + m.data[3] = data[i * 4 + j + 3]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x4(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + data[i * 4 + j + 3] = m.data[3]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix row(int i) const { + return slice_1x4(i, 0); + } + + Matrix &set_row(Matrix const &v, int i = 0) { + return set_slice_1x4(v, i, 0); + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 4]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 4] = m.data[1]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 4]; + m.data[3] = data[i * 4 + j + 5]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 4] = m.data[2]; + data[i * 4 + j + 5] = m.data[3]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + m.data[3] = data[i * 4 + j + 4]; + m.data[4] = data[i * 4 + j + 5]; + m.data[5] = data[i * 4 + j + 6]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + data[i * 4 + j + 4] = m.data[3]; + data[i * 4 + j + 5] = m.data[4]; + data[i * 4 + j + 6] = m.data[5]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x4(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = 
data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + m.data[3] = data[i * 4 + j + 3]; + m.data[4] = data[i * 4 + j + 4]; + m.data[5] = data[i * 4 + j + 5]; + m.data[6] = data[i * 4 + j + 6]; + m.data[7] = data[i * 4 + j + 7]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x4(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + data[i * 4 + j + 3] = m.data[3]; + data[i * 4 + j + 4] = m.data[4]; + data[i * 4 + j + 5] = m.data[5]; + data[i * 4 + j + 6] = m.data[6]; + data[i * 4 + j + 7] = m.data[7]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 4]; + m.data[2] = data[i * 4 + j + 8]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 4] = m.data[1]; + data[i * 4 + j + 8] = m.data[2]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix column(int j) const { + return slice_3x1(0, j); + } + + Matrix &set_column(Matrix const &v, int j =0) { + return set_slice_3x1(v, 0, j); + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 4]; + m.data[3] = data[i * 4 + j + 5]; + m.data[4] = data[i * 4 + j + 8]; + m.data[5] = data[i * 4 + j + 9]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 4] = m.data[2]; + data[i * 4 + j + 5] = m.data[3]; + data[i * 4 + j + 8] = m.data[4]; + data[i * 4 + j + 9] = m.data[5]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + m.data[3] = data[i * 4 + j + 4]; + m.data[4] = data[i * 4 + j + 5]; + m.data[5] = data[i * 4 + j + 6]; + m.data[6] = data[i * 4 + j + 8]; + m.data[7] = data[i * 4 + j + 9]; + m.data[8] = data[i * 4 + j + 10]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + data[i * 4 + j + 4] = m.data[3]; + data[i * 4 + j + 5] = m.data[4]; + data[i * 4 + j + 6] = m.data[5]; + data[i * 4 + j + 8] = m.data[6]; + data[i * 4 + j + 9] = m.data[7]; + data[i * 4 + j + 10] = m.data[8]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x4(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + m.data[3] = data[i * 4 + j + 3]; + m.data[4] = data[i * 4 + j + 4]; + m.data[5] = data[i * 4 + j + 5]; + m.data[6] = data[i * 4 + j + 6]; + m.data[7] = data[i * 4 + j + 7]; + m.data[8] = data[i * 4 + j + 8]; + m.data[9] = data[i * 4 + j + 9]; + 
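// (source offsets i * 4 + j + {8, 9, 10, 11} read row 2 of the slice; the row stride of 4 matches this matrix's 4-column, row-major storage)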
m.data[10] = data[i * 4 + j + 10]; + m.data[11] = data[i * 4 + j + 11]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x4(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + data[i * 4 + j + 3] = m.data[3]; + data[i * 4 + j + 4] = m.data[4]; + data[i * 4 + j + 5] = m.data[5]; + data[i * 4 + j + 6] = m.data[6]; + data[i * 4 + j + 7] = m.data[7]; + data[i * 4 + j + 8] = m.data[8]; + data[i * 4 + j + 9] = m.data[9]; + data[i * 4 + j + 10] = m.data[10]; + data[i * 4 + j + 11] = m.data[11]; + + return *this; + } + + /// Forms a 3-by-4 matrix by horizontally concatenating a 3-by-1 matrix with a 3-by-3 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2) + , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1), rhs.at(1, 2) + , lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1), rhs.at(2, 2)); + } + + /// Forms a 3-by-4 matrix by horizontally concatenating a 3-by-2 matrix with a 3-by-2 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1) + , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0), rhs.at(1, 1) + , lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0), rhs.at(2, 1)); + } + + /// Forms a 3-by-4 matrix by horizontally concatenating a 3-by-3 matrix with a 3-by-1 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), rhs.at(0, 0) + , lhs.at(1, 0), lhs.at(1, 1), lhs.at(1, 2), rhs.at(1, 0) + , lhs.at(2, 0), lhs.at(2, 1), lhs.at(2, 2), rhs.at(2, 0)); + } + + /// Forms a 3-by-4 matrix by vertically concatenating a 1-by-4 matrix with a 2-by-4 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Matrix const & lower) { + return Matrix( + upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) + , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3) + , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2), lower.at(1, 3)); + } + + /// Forms a 3-by-4 matrix by vertically concatenating a 2-by-4 matrix with a 1-by-4 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Matrix const & lower) { + return Matrix( + upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) + , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2), upper.at(1, 3) + , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3)); + } + + /// Concatenates this matrix with a a 1-by-4 matrix to form a 4-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix vcat(Matrix const & rhs) const { + return Matrix::vcat(*this, rhs); + } + + /// Forms a 3-by-4 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Element A, Matrix const & B, + Matrix const & C, Matrix const & D) { + return Matrix( + A, B.at(0, 0), B.at(0, 1), B.at(0, 2) + , C.at(0, 0), D.at(0, 0), D.at(0, 1), D.at(0, 2) + , C.at(1, 0), D.at(1, 0), D.at(1, 1), D.at(1, 2) + ); + } + + /// Forms a 3-by-4 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Matrix const & C, Matrix const & D) { + return Matrix( + A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) + , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) + , C.at(1, 0), C.at(1, 1), D.at(1, 0), D.at(1, 1) + ); + } + + /// Forms a 3-by-4 matrix 
by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Element B, + Matrix const & C, Matrix const & D) { + return Matrix( + A.at(0, 0), A.at(0, 1), A.at(0, 2), B + , C.at(0, 0), C.at(0, 1), C.at(0, 2), D.at(0, 0) + , C.at(1, 0), C.at(1, 1), C.at(1, 2), D.at(1, 0) + ); + } + + /// Forms a 3-by-4 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Element C, Matrix const & D) { + return Matrix( + A.at(0, 0), B.at(0, 0), B.at(0, 1), B.at(0, 2) + , A.at(1, 0), B.at(1, 0), B.at(1, 1), B.at(1, 2) + , C, D.at(0, 0), D.at(0, 1), D.at(0, 2) + ); + } + + /// Forms a 3-by-4 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Matrix const & C, Matrix const & D) { + return Matrix( + A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) + , A.at(1, 0), A.at(1, 1), B.at(1, 0), B.at(1, 1) + , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) + ); + } + + /// Forms a 3-by-4 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Matrix const & C, Element D) { + return Matrix( + A.at(0, 0), A.at(0, 1), A.at(0, 2), B.at(0, 0) + , A.at(1, 0), A.at(1, 1), A.at(1, 2), B.at(1, 0) + , C.at(0, 0), C.at(0, 1), C.at(0, 2), D + ); + } + + /// Elementwise add operator (3-by-4) + CUTLASS_HOST_DEVICE + Matrix add(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] + rhs.data[0]; + result.data[1] = data[1] + rhs.data[1]; + result.data[2] = data[2] + rhs.data[2]; + result.data[3] = data[3] + rhs.data[3]; + + result.data[4] = data[4] + rhs.data[4]; + result.data[5] = data[5] + rhs.data[5]; + result.data[6] = data[6] + rhs.data[6]; + result.data[7] = data[7] + rhs.data[7]; + + result.data[8] = data[8] + rhs.data[8]; + result.data[9] = data[9] + rhs.data[9]; + result.data[10] = data[10] + rhs.data[10]; + result.data[11] = data[11] + rhs.data[11]; + + return result; + } + + /// Elementwise add operator (3-by-4) + CUTLASS_HOST_DEVICE + Matrix operator +(Matrix const &rhs) const { + return add(rhs); + } + + /// Elementwise add operator (3-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator +=(Matrix const &rhs) { + + data[0] += rhs.data[0]; + data[1] += rhs.data[1]; + data[2] += rhs.data[2]; + data[3] += rhs.data[3]; + + data[4] += rhs.data[4]; + data[5] += rhs.data[5]; + data[6] += rhs.data[6]; + data[7] += rhs.data[7]; + + data[8] += rhs.data[8]; + data[9] += rhs.data[9]; + data[10] += rhs.data[10]; + data[11] += rhs.data[11]; + + return *this; + } + + /// Elementwise subtract operator (3-by-4) + CUTLASS_HOST_DEVICE + Matrix subtract(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] - rhs.data[0]; + result.data[1] = data[1] - rhs.data[1]; + result.data[2] = data[2] - rhs.data[2]; + result.data[3] = data[3] - rhs.data[3]; + + result.data[4] = data[4] - rhs.data[4]; + result.data[5] = data[5] - rhs.data[5]; + result.data[6] = data[6] - rhs.data[6]; + result.data[7] = data[7] - rhs.data[7]; + + result.data[8] = data[8] - rhs.data[8]; + result.data[9] = data[9] - rhs.data[9]; + result.data[10] = data[10] - rhs.data[10]; + result.data[11] = data[11] - rhs.data[11]; + + return result; + } + + /// Elementwise subtract operator (3-by-4) + CUTLASS_HOST_DEVICE + Matrix operator -(Matrix const &rhs) const { + return subtract(rhs); + } + + /// Elementwise subtract operator (3-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator -=(Matrix const &rhs) { + + 
data[0] -= rhs.data[0]; + data[1] -= rhs.data[1]; + data[2] -= rhs.data[2]; + data[3] -= rhs.data[3]; + + data[4] -= rhs.data[4]; + data[5] -= rhs.data[5]; + data[6] -= rhs.data[6]; + data[7] -= rhs.data[7]; + + data[8] -= rhs.data[8]; + data[9] -= rhs.data[9]; + data[10] -= rhs.data[10]; + data[11] -= rhs.data[11]; + + return *this; + } + + /// Elementwise multiply operator (3-by-4) + CUTLASS_HOST_DEVICE + Matrix multiply(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] * rhs.data[0]; + result.data[1] = data[1] * rhs.data[1]; + result.data[2] = data[2] * rhs.data[2]; + result.data[3] = data[3] * rhs.data[3]; + + result.data[4] = data[4] * rhs.data[4]; + result.data[5] = data[5] * rhs.data[5]; + result.data[6] = data[6] * rhs.data[6]; + result.data[7] = data[7] * rhs.data[7]; + + result.data[8] = data[8] * rhs.data[8]; + result.data[9] = data[9] * rhs.data[9]; + result.data[10] = data[10] * rhs.data[10]; + result.data[11] = data[11] * rhs.data[11]; + + return result; + } + + /// Scalar multiply operator (3-by-4) + CUTLASS_HOST_DEVICE + Matrix multiply(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] * s; + result.data[1] = data[1] * s; + result.data[2] = data[2] * s; + result.data[3] = data[3] * s; + + result.data[4] = data[4] * s; + result.data[5] = data[5] * s; + result.data[6] = data[6] * s; + result.data[7] = data[7] * s; + + result.data[8] = data[8] * s; + result.data[9] = data[9] * s; + result.data[10] = data[10] * s; + result.data[11] = data[11] * s; + + return result; + } + + /// Scalar multiply operator (3-by-4) + CUTLASS_HOST_DEVICE + Matrix operator *(Element const &s) const { + return multiply(s); + } + + /// Scalar multiply operator (3-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator *=(Element const &s) { + + data[0] *= s; + data[1] *= s; + data[2] *= s; + data[3] *= s; + + data[4] *= s; + data[5] *= s; + data[6] *= s; + data[7] *= s; + + data[8] *= s; + data[9] *= s; + data[10] *= s; + data[11] *= s; + + return *this; + } + + /// Elementwise divide operator (3-by-4) + CUTLASS_HOST_DEVICE + Matrix divide(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] / rhs.data[0]; + result.data[1] = data[1] / rhs.data[1]; + result.data[2] = data[2] / rhs.data[2]; + result.data[3] = data[3] / rhs.data[3]; + + result.data[4] = data[4] / rhs.data[4]; + result.data[5] = data[5] / rhs.data[5]; + result.data[6] = data[6] / rhs.data[6]; + result.data[7] = data[7] / rhs.data[7]; + + result.data[8] = data[8] / rhs.data[8]; + result.data[9] = data[9] / rhs.data[9]; + result.data[10] = data[10] / rhs.data[10]; + result.data[11] = data[11] / rhs.data[11]; + + return result; + } + + /// Scalar divide operator (3-by-4) + CUTLASS_HOST_DEVICE + Matrix divide(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] / s; + result.data[1] = data[1] / s; + result.data[2] = data[2] / s; + result.data[3] = data[3] / s; + + result.data[4] = data[4] / s; + result.data[5] = data[5] / s; + result.data[6] = data[6] / s; + result.data[7] = data[7] / s; + + result.data[8] = data[8] / s; + result.data[9] = data[9] / s; + result.data[10] = data[10] / s; + result.data[11] = data[11] / s; + + return result; + } + + /// Scalar divide operator (3-by-4) + CUTLASS_HOST_DEVICE + Matrix operator /(Element const &s) const { + return divide(s); + } + + /// Scalar divide operator (3-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Element const &s) { + + data[0] /= s; + data[1] /= s; + data[2] /= s; + data[3] /= s; + + data[4] /= s; + 
data[5] /= s; + data[6] /= s; + data[7] /= s; + + data[8] /= s; + data[9] /= s; + data[10] /= s; + data[11] /= s; + + return *this; + } + + /// Elementwise divide operator (3-by-4) + CUTLASS_HOST_DEVICE + Matrix operator /(Matrix const &rhs) const { + return divide(rhs); + } + + /// Elementwise divide operator (3-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Matrix const &rhs) { + + data[0] /= rhs.data[0]; + data[1] /= rhs.data[1]; + data[2] /= rhs.data[2]; + data[3] /= rhs.data[3]; + + data[4] /= rhs.data[4]; + data[5] /= rhs.data[5]; + data[6] /= rhs.data[6]; + data[7] /= rhs.data[7]; + + data[8] /= rhs.data[8]; + data[9] /= rhs.data[9]; + data[10] /= rhs.data[10]; + data[11] /= rhs.data[11]; + + return *this; + } + + /// Negates each element of the matrix + CUTLASS_HOST_DEVICE + Matrix operator-() const { + Matrix m; + + m.data[0] = -m.data[0]; + m.data[1] = -m.data[1]; + m.data[2] = -m.data[2]; + m.data[3] = -m.data[3]; + m.data[4] = -m.data[4]; + m.data[5] = -m.data[5]; + m.data[6] = -m.data[6]; + m.data[7] = -m.data[7]; + m.data[8] = -m.data[8]; + m.data[9] = -m.data[9]; + m.data[10] = -m.data[10]; + m.data[11] = -m.data[11]; + + return m; + } + + /// Matrix product of size 3-by-1-by-4 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[4] * rhs.data[0]; + accum.data[2] += data[8] * rhs.data[0]; + + // k=1 + accum.data[0] += data[1] * rhs.data[1]; + accum.data[1] += data[5] * rhs.data[1]; + accum.data[2] += data[9] * rhs.data[1]; + + // k=2 + accum.data[0] += data[2] * rhs.data[2]; + accum.data[1] += data[6] * rhs.data[2]; + accum.data[2] += data[10] * rhs.data[2]; + + // k=3 + accum.data[0] += data[3] * rhs.data[3]; + accum.data[1] += data[7] * rhs.data[3]; + accum.data[2] += data[11] * rhs.data[3]; + + return accum; + } + + /// Matrix product of size 3-by-1-by-4 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 3-by-2-by-4 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[4] * rhs.data[0]; + accum.data[3] += data[4] * rhs.data[1]; + accum.data[4] += data[8] * rhs.data[0]; + accum.data[5] += data[8] * rhs.data[1]; + + // k=1 + accum.data[0] += data[1] * rhs.data[2]; + accum.data[1] += data[1] * rhs.data[3]; + accum.data[2] += data[5] * rhs.data[2]; + accum.data[3] += data[5] * rhs.data[3]; + accum.data[4] += data[9] * rhs.data[2]; + accum.data[5] += data[9] * rhs.data[3]; + + // k=2 + accum.data[0] += data[2] * rhs.data[4]; + accum.data[1] += data[2] * rhs.data[5]; + accum.data[2] += data[6] * rhs.data[4]; + accum.data[3] += data[6] * rhs.data[5]; + accum.data[4] += data[10] * rhs.data[4]; + accum.data[5] += data[10] * rhs.data[5]; + + // k=3 + accum.data[0] += data[3] * rhs.data[6]; + accum.data[1] += data[3] * rhs.data[7]; + accum.data[2] += data[7] * rhs.data[6]; + accum.data[3] += data[7] * rhs.data[7]; + accum.data[4] += data[11] * rhs.data[6]; + accum.data[5] += data[11] * rhs.data[7]; + + return accum; + } + + /// Matrix product of size 3-by-2-by-4 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 3-by-3-by-4 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += 
data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[4] * rhs.data[0]; + accum.data[4] += data[4] * rhs.data[1]; + accum.data[5] += data[4] * rhs.data[2]; + accum.data[6] += data[8] * rhs.data[0]; + accum.data[7] += data[8] * rhs.data[1]; + accum.data[8] += data[8] * rhs.data[2]; + + // k=1 + accum.data[0] += data[1] * rhs.data[3]; + accum.data[1] += data[1] * rhs.data[4]; + accum.data[2] += data[1] * rhs.data[5]; + accum.data[3] += data[5] * rhs.data[3]; + accum.data[4] += data[5] * rhs.data[4]; + accum.data[5] += data[5] * rhs.data[5]; + accum.data[6] += data[9] * rhs.data[3]; + accum.data[7] += data[9] * rhs.data[4]; + accum.data[8] += data[9] * rhs.data[5]; + + // k=2 + accum.data[0] += data[2] * rhs.data[6]; + accum.data[1] += data[2] * rhs.data[7]; + accum.data[2] += data[2] * rhs.data[8]; + accum.data[3] += data[6] * rhs.data[6]; + accum.data[4] += data[6] * rhs.data[7]; + accum.data[5] += data[6] * rhs.data[8]; + accum.data[6] += data[10] * rhs.data[6]; + accum.data[7] += data[10] * rhs.data[7]; + accum.data[8] += data[10] * rhs.data[8]; + + // k=3 + accum.data[0] += data[3] * rhs.data[9]; + accum.data[1] += data[3] * rhs.data[10]; + accum.data[2] += data[3] * rhs.data[11]; + accum.data[3] += data[7] * rhs.data[9]; + accum.data[4] += data[7] * rhs.data[10]; + accum.data[5] += data[7] * rhs.data[11]; + accum.data[6] += data[11] * rhs.data[9]; + accum.data[7] += data[11] * rhs.data[10]; + accum.data[8] += data[11] * rhs.data[11]; + + return accum; + } + + /// Matrix product of size 3-by-3-by-4 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 3-by-4-by-4 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[0] * rhs.data[3]; + accum.data[4] += data[4] * rhs.data[0]; + accum.data[5] += data[4] * rhs.data[1]; + accum.data[6] += data[4] * rhs.data[2]; + accum.data[7] += data[4] * rhs.data[3]; + accum.data[8] += data[8] * rhs.data[0]; + accum.data[9] += data[8] * rhs.data[1]; + accum.data[10] += data[8] * rhs.data[2]; + accum.data[11] += data[8] * rhs.data[3]; + + // k=1 + accum.data[0] += data[1] * rhs.data[4]; + accum.data[1] += data[1] * rhs.data[5]; + accum.data[2] += data[1] * rhs.data[6]; + accum.data[3] += data[1] * rhs.data[7]; + accum.data[4] += data[5] * rhs.data[4]; + accum.data[5] += data[5] * rhs.data[5]; + accum.data[6] += data[5] * rhs.data[6]; + accum.data[7] += data[5] * rhs.data[7]; + accum.data[8] += data[9] * rhs.data[4]; + accum.data[9] += data[9] * rhs.data[5]; + accum.data[10] += data[9] * rhs.data[6]; + accum.data[11] += data[9] * rhs.data[7]; + + // k=2 + accum.data[0] += data[2] * rhs.data[8]; + accum.data[1] += data[2] * rhs.data[9]; + accum.data[2] += data[2] * rhs.data[10]; + accum.data[3] += data[2] * rhs.data[11]; + accum.data[4] += data[6] * rhs.data[8]; + accum.data[5] += data[6] * rhs.data[9]; + accum.data[6] += data[6] * rhs.data[10]; + accum.data[7] += data[6] * rhs.data[11]; + accum.data[8] += data[10] * rhs.data[8]; + accum.data[9] += data[10] * rhs.data[9]; + accum.data[10] += data[10] * rhs.data[10]; + accum.data[11] += data[10] * rhs.data[11]; + + // k=3 + accum.data[0] += data[3] * rhs.data[12]; + accum.data[1] += data[3] * rhs.data[13]; + accum.data[2] += data[3] * rhs.data[14]; + 
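// (fully unrolled GEMM update: accum.data[m * 4 + n] += data[m * 4 + k] * rhs.data[k * 4 + n]; for this k=3 step the lhs operands are data[3], data[7], data[11] and the rhs operands are rhs.data[12..15], i.e. row 3 of the row-major 4-by-4 rhs)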
accum.data[3] += data[3] * rhs.data[15]; + accum.data[4] += data[7] * rhs.data[12]; + accum.data[5] += data[7] * rhs.data[13]; + accum.data[6] += data[7] * rhs.data[14]; + accum.data[7] += data[7] * rhs.data[15]; + accum.data[8] += data[11] * rhs.data[12]; + accum.data[9] += data[11] * rhs.data[13]; + accum.data[10] += data[11] * rhs.data[14]; + accum.data[11] += data[11] * rhs.data[15]; + + return accum; + } + + /// Matrix product of size 3-by-4-by-4 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 3-by-4-by-4 + CUTLASS_HOST_DEVICE + Matrix & operator*=(Matrix const &rhs) { + *this = product(rhs); + return *this; + } + + /// Returns the sum of elements + CUTLASS_HOST_DEVICE + Element sum(Element accum = Element()) const { + + accum += data[0]; + accum += data[1]; + accum += data[2]; + accum += data[3]; + accum += data[4]; + accum += data[5]; + accum += data[6]; + accum += data[7]; + accum += data[8]; + accum += data[9]; + accum += data[10]; + accum += data[11]; + + return accum; + } + + /// Returns the sum of squared elements + CUTLASS_HOST_DEVICE + Element norm(Element accum = Element()) const { + + accum += data[0] * data[0]; + accum += data[1] * data[1]; + accum += data[2] * data[2]; + accum += data[3] * data[3]; + accum += data[4] * data[4]; + accum += data[5] * data[5]; + accum += data[6] * data[6]; + accum += data[7] * data[7]; + accum += data[8] * data[8]; + accum += data[9] * data[9]; + accum += data[10] * data[10]; + accum += data[11] * data[11]; + + return accum; + } + + /// Returns square root of the norm + CUTLASS_HOST_DEVICE + Element magnitude() const { + return fast_sqrt(norm()); + } + + /// Returns the sum of diagonal elements + CUTLASS_HOST_DEVICE + Element trace(Element accum = Element()) const { + + accum += data[0]; + accum += data[5]; + accum += data[10]; + + return accum; + } + +}; + +/// Template alias for 3-by-4 matrix +template +using Matrix3x4 = Matrix; + + +/// Free funciton to infer element type from template arguments +template +CUTLASS_HOST_DEVICE Matrix3x4 make_Matrix3x4( + Element _0_0, Element _0_1, Element _0_2, Element _0_3, + Element _1_0, Element _1_1, Element _1_2, Element _1_3, + Element _2_0, Element _2_1, Element _2_2, Element _2_3 +) { + return Matrix3x4( + _0_0, _0_1, _0_2, _0_3, + _1_0, _1_1, _1_2, _1_3, + _2_0, _2_1, _2_2, _2_3 + ); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 4-by-1 matrix template class definition +template +struct Matrix { + + // + // Type definitions + // + + /// Element data type + using Element = Element_; + + /// Number of rows in matrix + static int const kRows = 4; + + /// Number of columns in matrix + static int const kColumns = 1; + + /// Layout of matrix in underlying array + using Layout = layout::RowMajor; + + /// Number of elements in matrix + static int const kCount = 4; + + // + // Data members + // + + /// Elements of the matrix in row-major layout + Array data; + + // + // Methods + // + + /// Constructs a zero matrix + CUTLASS_HOST_DEVICE + Matrix() { + data.clear(); + } + + /// Copy constructor for a 4-by-1 matrix + CUTLASS_HOST_DEVICE + Matrix(Matrix const &rhs) { + data = rhs.data; + } + + /// Constucts a 4-by-1 matrix from scalar elements + CUTLASS_HOST_DEVICE + Matrix( + Element _0_0, + Element _1_0, + Element _2_0, + Element _3_0 + ) { + + data[0] = _0_0; + data[1] = _1_0; + data[2] = _2_0; + data[3] = _3_0; + } + + /// Constructs a matrix from a uniform 
element + CUTLASS_HOST_DEVICE + static Matrix uniform(Element s) { + Matrix m; + + m.data[0] = s; + m.data[1] = s; + m.data[2] = s; + m.data[3] = s; + + return m; + } + + /// Constructs a matrix from a uniform element 1 + CUTLASS_HOST_DEVICE + static Matrix ones() { + return uniform(Element(1)); + } + + /// Constructs a matrix from a uniform element 0 + CUTLASS_HOST_DEVICE + static Matrix zero() { + return Matrix(); + } + + /// Returns a transposed matrix + CUTLASS_HOST_DEVICE + Matrix transpose() const { + Matrix mt; + + mt.data[0] = data[0]; + mt.data[1] = data[1]; + mt.data[2] = data[2]; + mt.data[3] = data[3]; + + return mt; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(int i, int j) const { + return data[i * 1 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(int i, int j) { + return data[i * 1 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element &at(int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element at(int offset) const { + return data[offset]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element operator[](Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & operator[](Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element & operator[](int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element operator[](int offset) const { + return data[offset]; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 1 + j + 0]; + m.data[1] = data[i * 1 + j + 1]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 1 + j + 0] = m.data[0]; + data[i * 1 + j + 1] = m.data[1]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 1 + j + 0]; + m.data[1] = data[i * 1 + j + 1]; + m.data[2] = data[i * 1 + j + 2]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 1 + j + 0] = m.data[0]; + data[i * 1 + j + 1] = m.data[1]; + data[i * 1 + j + 2] = m.data[2]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_4x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 1 + j + 0]; + m.data[1] = data[i * 1 + j + 1]; + m.data[2] = data[i * 1 + j + 2]; + m.data[3] = data[i * 1 + j + 3]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_4x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 1 + j + 0] = m.data[0]; + data[i * 1 + j + 1] = m.data[1]; + data[i * 1 + j + 2] = m.data[2]; + data[i * 1 + j + 3] = m.data[3]; + + return *this; + } + + 
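A minimal host-side sketch of the 4-by-1 interface above may help make the row-major accessor and slice/set_slice pattern concrete. It is illustrative only and not part of the patch: the function and variable names are hypothetical, and it assumes the cutlass namespace and the Matrix<Element, Rows, Columns> template that this header defines.

#include "cutlass/matrix.h"

// Exercise the 4-by-1 constructor, accessors, and slices shown above.
void example_vector_ops() {
  cutlass::Matrix<float, 4, 1> v(1.0f, 2.0f, 3.0f, 4.0f);

  float x = v.at(2, 0);                               // row-major: data[2 * 1 + 0]
  v[3] = x;                                           // linear-offset access

  cutlass::Matrix<float, 2, 1> head = v.slice_2x1();  // rows 0..1
  v.set_slice_2x1(head, 2, 0);                        // overwrite rows 2..3
}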
CUTLASS_HOST_DEVICE + Matrix column(int j) const { + return slice_4x1(0, j); + } + + Matrix &set_column(Matrix const &v, int j =0) { + return set_slice_4x1(v, 0, j); + } + + /// Concatenates this matrix with a a 4-by-1 matrix to form a 4-by-2 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Concatenates this matrix with a a 4-by-2 matrix to form a 4-by-3 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Concatenates this matrix with a a 4-by-3 matrix to form a 4-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Forms a 4-by-1 matrix by vertically concatenating an Element with a 3-by-1 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Element upper, Matrix const & lower) { + return Matrix( + upper + , lower.at(0, 0) + , lower.at(1, 0) + , lower.at(2, 0)); + } + + /// Forms a 4-by-1 matrix by vertically concatenating a 2-by-1 matrix with a 2-by-1 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Matrix const & lower) { + return Matrix( + upper.at(0, 0) + , upper.at(1, 0) + , lower.at(0, 0) + , lower.at(1, 0)); + } + + /// Forms a 4-by-1 matrix by vertically concatenating a 3-by-1 matrix with an Element + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Element lower) { + return Matrix( + upper.at(0, 0) + , upper.at(1, 0) + , upper.at(2, 0) + , lower); + } + + /// Elementwise add operator (4-by-1) + CUTLASS_HOST_DEVICE + Matrix add(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] + rhs.data[0]; + + result.data[1] = data[1] + rhs.data[1]; + + result.data[2] = data[2] + rhs.data[2]; + + result.data[3] = data[3] + rhs.data[3]; + + return result; + } + + /// Elementwise add operator (4-by-1) + CUTLASS_HOST_DEVICE + Matrix operator +(Matrix const &rhs) const { + return add(rhs); + } + + /// Elementwise add operator (4-by-1) + CUTLASS_HOST_DEVICE + Matrix & operator +=(Matrix const &rhs) { + + data[0] += rhs.data[0]; + + data[1] += rhs.data[1]; + + data[2] += rhs.data[2]; + + data[3] += rhs.data[3]; + + return *this; + } + + /// Elementwise subtract operator (4-by-1) + CUTLASS_HOST_DEVICE + Matrix subtract(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] - rhs.data[0]; + + result.data[1] = data[1] - rhs.data[1]; + + result.data[2] = data[2] - rhs.data[2]; + + result.data[3] = data[3] - rhs.data[3]; + + return result; + } + + /// Elementwise subtract operator (4-by-1) + CUTLASS_HOST_DEVICE + Matrix operator -(Matrix const &rhs) const { + return subtract(rhs); + } + + /// Elementwise subtract operator (4-by-1) + CUTLASS_HOST_DEVICE + Matrix & operator -=(Matrix const &rhs) { + + data[0] -= rhs.data[0]; + + data[1] -= rhs.data[1]; + + data[2] -= rhs.data[2]; + + data[3] -= rhs.data[3]; + + return *this; + } + + /// Elementwise multiply operator (4-by-1) + CUTLASS_HOST_DEVICE + Matrix multiply(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] * rhs.data[0]; + + result.data[1] = data[1] * rhs.data[1]; + + result.data[2] = data[2] * rhs.data[2]; + + result.data[3] = data[3] * rhs.data[3]; + + return result; + } + + /// Scalar multiply operator (4-by-1) + CUTLASS_HOST_DEVICE + Matrix multiply(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] * s; + + result.data[1] = data[1] * s; + + result.data[2] = data[2] * s; + + result.data[3] = data[3] * s; + + return result; 
+ } + + /// Scalar multiply operator (4-by-1) + CUTLASS_HOST_DEVICE + Matrix operator *(Element const &s) const { + return multiply(s); + } + + /// Scalar multiply operator (4-by-1) + CUTLASS_HOST_DEVICE + Matrix & operator *=(Element const &s) { + + data[0] *= s; + + data[1] *= s; + + data[2] *= s; + + data[3] *= s; + + return *this; + } + + /// Elementwise divide operator (4-by-1) + CUTLASS_HOST_DEVICE + Matrix divide(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] / rhs.data[0]; + + result.data[1] = data[1] / rhs.data[1]; + + result.data[2] = data[2] / rhs.data[2]; + + result.data[3] = data[3] / rhs.data[3]; + + return result; + } + + /// Scalar divide operator (4-by-1) + CUTLASS_HOST_DEVICE + Matrix divide(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] / s; + + result.data[1] = data[1] / s; + + result.data[2] = data[2] / s; + + result.data[3] = data[3] / s; + + return result; + } + + /// Scalar divide operator (4-by-1) + CUTLASS_HOST_DEVICE + Matrix operator /(Element const &s) const { + return divide(s); + } + + /// Scalar divide operator (4-by-1) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Element const &s) { + + data[0] /= s; + + data[1] /= s; + + data[2] /= s; + + data[3] /= s; + + return *this; + } + + /// Elementwise divide operator (4-by-1) + CUTLASS_HOST_DEVICE + Matrix operator /(Matrix const &rhs) const { + return divide(rhs); + } + + /// Elementwise divide operator (4-by-1) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Matrix const &rhs) { + + data[0] /= rhs.data[0]; + + data[1] /= rhs.data[1]; + + data[2] /= rhs.data[2]; + + data[3] /= rhs.data[3]; + + return *this; + } + + /// Negates each element of the matrix + CUTLASS_HOST_DEVICE + Matrix operator-() const { + Matrix m; + + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; + + return m; + } + + /// Matrix product of size 4-by-1-by-1 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[1] * rhs.data[0]; + accum.data[2] += data[2] * rhs.data[0]; + accum.data[3] += data[3] * rhs.data[0]; + + return accum; + } + + /// Matrix product of size 4-by-1-by-1 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 4-by-1-by-1 + CUTLASS_HOST_DEVICE + Matrix & operator*=(Matrix const &rhs) { + *this = product(rhs); + return *this; + } + + /// Matrix product of size 4-by-2-by-1 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[1] * rhs.data[0]; + accum.data[3] += data[1] * rhs.data[1]; + accum.data[4] += data[2] * rhs.data[0]; + accum.data[5] += data[2] * rhs.data[1]; + accum.data[6] += data[3] * rhs.data[0]; + accum.data[7] += data[3] * rhs.data[1]; + + return accum; + } + + /// Matrix product of size 4-by-2-by-1 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 4-by-3-by-1 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[1] * rhs.data[0]; + accum.data[4] += data[1] * rhs.data[1]; + 
accum.data[5] += data[1] * rhs.data[2]; + accum.data[6] += data[2] * rhs.data[0]; + accum.data[7] += data[2] * rhs.data[1]; + accum.data[8] += data[2] * rhs.data[2]; + accum.data[9] += data[3] * rhs.data[0]; + accum.data[10] += data[3] * rhs.data[1]; + accum.data[11] += data[3] * rhs.data[2]; + + return accum; + } + + /// Matrix product of size 4-by-3-by-1 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 4-by-4-by-1 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[0] * rhs.data[3]; + accum.data[4] += data[1] * rhs.data[0]; + accum.data[5] += data[1] * rhs.data[1]; + accum.data[6] += data[1] * rhs.data[2]; + accum.data[7] += data[1] * rhs.data[3]; + accum.data[8] += data[2] * rhs.data[0]; + accum.data[9] += data[2] * rhs.data[1]; + accum.data[10] += data[2] * rhs.data[2]; + accum.data[11] += data[2] * rhs.data[3]; + accum.data[12] += data[3] * rhs.data[0]; + accum.data[13] += data[3] * rhs.data[1]; + accum.data[14] += data[3] * rhs.data[2]; + accum.data[15] += data[3] * rhs.data[3]; + + return accum; + } + + /// Matrix product of size 4-by-4-by-1 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Dot product of vectors with extent 4 + CUTLASS_HOST_DEVICE + Element dot(Matrix const &rhs, Element accum = Element()) const { + + accum += data[0] * rhs.data[0]; + accum += data[1] * rhs.data[1]; + accum += data[2] * rhs.data[2]; + accum += data[3] * rhs.data[3]; + return accum; + } + + /// Dot product of vectors with extent 4 + CUTLASS_HOST_DEVICE + Element dot(Matrix const &rhs, Element accum = Element()) const { + + accum += data[0] * rhs.data[0]; + accum += data[1] * rhs.data[1]; + accum += data[2] * rhs.data[2]; + accum += data[3] * rhs.data[3]; + return accum; + } + + /// Returns the sum of elements + CUTLASS_HOST_DEVICE + Element sum(Element accum = Element()) const { + + accum += data[0]; + accum += data[1]; + accum += data[2]; + accum += data[3]; + + return accum; + } + + /// Returns the sum of squared elements + CUTLASS_HOST_DEVICE + Element norm(Element accum = Element()) const { + + accum += data[0] * data[0]; + accum += data[1] * data[1]; + accum += data[2] * data[2]; + accum += data[3] * data[3]; + + return accum; + } + + /// Returns square root of the norm + CUTLASS_HOST_DEVICE + Element magnitude() const { + return fast_sqrt(norm()); + } + + /// Returns the sum of diagonal elements + CUTLASS_HOST_DEVICE + Element trace(Element accum = Element()) const { + + accum += data[0]; + + return accum; + } + +}; + +/// Template alias for 4-by-1 matrix +template +using Matrix4x1 = Matrix; + + +/// Free funciton to infer element type from template arguments +template +CUTLASS_HOST_DEVICE Matrix4x1 make_Matrix4x1( + Element _0_0, + Element _1_0, + Element _2_0, + Element _3_0 +) { + return Matrix4x1( + _0_0, + _1_0, + _2_0, + _3_0 + ); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 4-by-2 matrix template class definition +template +struct Matrix { + + // + // Type definitions + // + + /// Element data type + using Element = Element_; + + /// Number of rows in matrix + static int const kRows = 4; + + /// Number of columns in matrix + static int const kColumns = 2; + + /// Layout of matrix in 
underlying array + using Layout = layout::RowMajor; + + /// Number of elements in matrix + static int const kCount = 8; + + // + // Data members + // + + /// Elements of the matrix in row-major layout + Array data; + + // + // Methods + // + + /// Constructs a zero matrix + CUTLASS_HOST_DEVICE + Matrix() { + data.clear(); + } + + /// Copy constructor for a 4-by-2 matrix + CUTLASS_HOST_DEVICE + Matrix(Matrix const &rhs) { + data = rhs.data; + } + + /// Constucts a 4-by-2 matrix from scalar elements + CUTLASS_HOST_DEVICE + Matrix( + Element _0_0, Element _0_1, + Element _1_0, Element _1_1, + Element _2_0, Element _2_1, + Element _3_0, Element _3_1 + ) { + + data[0] = _0_0; data[1] = _0_1; + data[2] = _1_0; data[3] = _1_1; + data[4] = _2_0; data[5] = _2_1; + data[6] = _3_0; data[7] = _3_1; + } + + /// Constucts a 4-by-2 matrix from row vectors + CUTLASS_HOST_DEVICE + Matrix( + Matrix const &row_0, + Matrix const &row_1, + Matrix const &row_2, + Matrix const &row_3 + ) { + data[0] = row_0.data[0]; + data[1] = row_0.data[1]; + data[2] = row_1.data[0]; + data[3] = row_1.data[1]; + data[4] = row_2.data[0]; + data[5] = row_2.data[1]; + data[6] = row_3.data[0]; + data[7] = row_3.data[1]; + } + + /// Static method to construct a 4-by-2 matrix from column vectors + CUTLASS_HOST_DEVICE + static Matrix from_columns( + Matrix const &column_0, + Matrix const &column_1 + ) { + Matrix result; + + result.data[0] = column_0.data[0]; + result.data[1] = column_1.data[0]; + result.data[2] = column_0.data[1]; + result.data[3] = column_1.data[1]; + result.data[4] = column_0.data[2]; + result.data[5] = column_1.data[2]; + result.data[6] = column_0.data[3]; + result.data[7] = column_1.data[3]; + return result; + } + + /// Constructs a matrix from a uniform element + CUTLASS_HOST_DEVICE + static Matrix uniform(Element s) { + Matrix m; + + m.data[0] = s; + m.data[1] = s; + m.data[2] = s; + m.data[3] = s; + m.data[4] = s; + m.data[5] = s; + m.data[6] = s; + m.data[7] = s; + + return m; + } + + /// Constructs a matrix from a uniform element 1 + CUTLASS_HOST_DEVICE + static Matrix ones() { + return uniform(Element(1)); + } + + /// Constructs a matrix from a uniform element 0 + CUTLASS_HOST_DEVICE + static Matrix zero() { + return Matrix(); + } + + /// Constructs a matrix from elements along its diagonal + CUTLASS_HOST_DEVICE + static Matrix from_diagonal(Matrix const &diag) { + Matrix m; + + m.data[0] = diag.data[0]; + m.data[5] = diag.data[1]; + m.data[10] = diag.data[2]; + m.data[15] = diag.data[3]; + + return m; + } + + /// Constructs a matrix from elements along its diagonal + CUTLASS_HOST_DEVICE + static Matrix from_diagonal(Matrix const &diag) { + Matrix m; + + m.data[0] = diag.data[0]; + m.data[5] = diag.data[1]; + m.data[10] = diag.data[2]; + m.data[15] = diag.data[3]; + + return m; + } + + /// Gets an array of diagonal elements + CUTLASS_HOST_DEVICE + Matrix diagonal() const { + Matrix diag; + + diag.data[0] = data[0]; + diag.data[1] = data[5]; + diag.data[2] = data[10]; + diag.data[3] = data[15]; + + return diag; + } + + /// Returns a transposed matrix + CUTLASS_HOST_DEVICE + Matrix transpose() const { + Matrix mt; + + mt.data[0] = data[0]; + mt.data[4] = data[1]; + mt.data[1] = data[2]; + mt.data[5] = data[3]; + mt.data[2] = data[4]; + mt.data[6] = data[5]; + mt.data[3] = data[6]; + mt.data[7] = data[7]; + + return mt; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(int i, int j) const { + return data[i * 4 + j]; + } + + /// Accesses an element by coordinate + 
CUTLASS_HOST_DEVICE + Element & at(int i, int j) { + return data[i * 4 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element &at(int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element at(int offset) const { + return data[offset]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element operator[](Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & operator[](Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element & operator[](int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element operator[](int offset) const { + return data[offset]; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 2 + j + 0]; + m.data[1] = data[i * 2 + j + 1]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 2 + j + 0] = m.data[0]; + data[i * 2 + j + 1] = m.data[1]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix row(int i) const { + return slice_1x2(i, 0); + } + + Matrix &set_row(Matrix const &v, int i = 0) { + return set_slice_1x2(v, i, 0); + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 2 + j + 0]; + m.data[1] = data[i * 2 + j + 2]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 2 + j + 0] = m.data[0]; + data[i * 2 + j + 2] = m.data[1]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 2 + j + 0]; + m.data[1] = data[i * 2 + j + 1]; + m.data[2] = data[i * 2 + j + 2]; + m.data[3] = data[i * 2 + j + 3]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 2 + j + 0] = m.data[0]; + data[i * 2 + j + 1] = m.data[1]; + data[i * 2 + j + 2] = m.data[2]; + data[i * 2 + j + 3] = m.data[3]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 2 + j + 0]; + m.data[1] = data[i * 2 + j + 2]; + m.data[2] = data[i * 2 + j + 4]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 2 + j + 0] = m.data[0]; + data[i * 2 + j + 2] = m.data[1]; + data[i * 2 + j + 4] = m.data[2]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 2 + j + 0]; + m.data[1] = data[i * 2 + j + 1]; + 
m.data[2] = data[i * 2 + j + 2]; + m.data[3] = data[i * 2 + j + 3]; + m.data[4] = data[i * 2 + j + 4]; + m.data[5] = data[i * 2 + j + 5]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 2 + j + 0] = m.data[0]; + data[i * 2 + j + 1] = m.data[1]; + data[i * 2 + j + 2] = m.data[2]; + data[i * 2 + j + 3] = m.data[3]; + data[i * 2 + j + 4] = m.data[4]; + data[i * 2 + j + 5] = m.data[5]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_4x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 2 + j + 0]; + m.data[1] = data[i * 2 + j + 2]; + m.data[2] = data[i * 2 + j + 4]; + m.data[3] = data[i * 2 + j + 6]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_4x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 2 + j + 0] = m.data[0]; + data[i * 2 + j + 2] = m.data[1]; + data[i * 2 + j + 4] = m.data[2]; + data[i * 2 + j + 6] = m.data[3]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix column(int j) const { + return slice_4x1(0, j); + } + + Matrix &set_column(Matrix const &v, int j =0) { + return set_slice_4x1(v, 0, j); + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_4x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 2 + j + 0]; + m.data[1] = data[i * 2 + j + 1]; + m.data[2] = data[i * 2 + j + 2]; + m.data[3] = data[i * 2 + j + 3]; + m.data[4] = data[i * 2 + j + 4]; + m.data[5] = data[i * 2 + j + 5]; + m.data[6] = data[i * 2 + j + 6]; + m.data[7] = data[i * 2 + j + 7]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_4x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 2 + j + 0] = m.data[0]; + data[i * 2 + j + 1] = m.data[1]; + data[i * 2 + j + 2] = m.data[2]; + data[i * 2 + j + 3] = m.data[3]; + data[i * 2 + j + 4] = m.data[4]; + data[i * 2 + j + 5] = m.data[5]; + data[i * 2 + j + 6] = m.data[6]; + data[i * 2 + j + 7] = m.data[7]; + + return *this; + } + + /// Forms a 4-by-2 matrix by horizontally concatenating a 4-by-1 matrix with a 4-by-1 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), rhs.at(0, 0) + , lhs.at(1, 0), rhs.at(1, 0) + , lhs.at(2, 0), rhs.at(2, 0) + , lhs.at(3, 0), rhs.at(3, 0)); + } + + /// Concatenates this matrix with a a 4-by-1 matrix to form a 4-by-3 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Concatenates this matrix with a a 4-by-2 matrix to form a 4-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Forms a 4-by-2 matrix by vertically concatenating a 1-by-2 matrix with a 3-by-2 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Matrix const & lower) { + return Matrix( + upper.at(0, 0), upper.at(0, 1) + , lower.at(0, 0), lower.at(0, 1) + , lower.at(1, 0), lower.at(1, 1) + , lower.at(2, 0), lower.at(2, 1)); + } + + /// Forms a 4-by-2 matrix by vertically concatenating a 2-by-2 matrix with a 2-by-2 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Matrix const & lower) { + return Matrix( + upper.at(0, 0), upper.at(0, 1) + , upper.at(1, 0), upper.at(1, 1) + , lower.at(0, 0), lower.at(0, 1) + , lower.at(1, 0), lower.at(1, 1)); + } + 
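A short host-side sketch of the concatenation helpers above, illustrative only and not part of the patch; names are hypothetical and it assumes the cutlass namespace and the Matrix<Element, Rows, Columns> template defined in this header.

#include "cutlass/matrix.h"

// Build a 4-by-2 matrix from two 4-by-1 columns with the static hcat() shown
// above, then use the row/column views of the 4-by-2 class.
void example_concat() {
  cutlass::Matrix<float, 4, 1> c0 = cutlass::Matrix<float, 4, 1>::ones();
  cutlass::Matrix<float, 4, 1> c1 = cutlass::Matrix<float, 4, 1>::uniform(2.0f);

  cutlass::Matrix<float, 4, 2> m = cutlass::Matrix<float, 4, 2>::hcat(c0, c1);

  cutlass::Matrix<float, 1, 2> r0 = m.row(0);   // first row as a 1-by-2 matrix
  m.set_column(c0, 1);                          // overwrite the second column
  m.set_row(r0, 3);                             // overwrite the last row
}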
+ /// Forms a 4-by-2 matrix by vertically concatenating a 3-by-2 matrix with a 1-by-2 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Matrix const & lower) { + return Matrix( + upper.at(0, 0), upper.at(0, 1) + , upper.at(1, 0), upper.at(1, 1) + , upper.at(2, 0), upper.at(2, 1) + , lower.at(0, 0), lower.at(0, 1)); + } + + /// Forms a 4-by-2 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Element A, Element B, + Matrix const & C, Matrix const & D) { + return Matrix( + A, B + , C.at(0, 0), D.at(0, 0) + , C.at(1, 0), D.at(1, 0) + , C.at(2, 0), D.at(2, 0) + ); + } + + /// Forms a 4-by-2 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Matrix const & C, Matrix const & D) { + return Matrix( + A.at(0, 0), B.at(0, 0) + , A.at(1, 0), B.at(1, 0) + , C.at(0, 0), D.at(0, 0) + , C.at(1, 0), D.at(1, 0) + ); + } + + /// Forms a 4-by-2 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Element C, Element D) { + return Matrix( + A.at(0, 0), B.at(0, 0) + , A.at(1, 0), B.at(1, 0) + , A.at(2, 0), B.at(2, 0) + , C, D + ); + } + + /// Elementwise add operator (4-by-2) + CUTLASS_HOST_DEVICE + Matrix add(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] + rhs.data[0]; + result.data[1] = data[1] + rhs.data[1]; + + result.data[2] = data[2] + rhs.data[2]; + result.data[3] = data[3] + rhs.data[3]; + + result.data[4] = data[4] + rhs.data[4]; + result.data[5] = data[5] + rhs.data[5]; + + result.data[6] = data[6] + rhs.data[6]; + result.data[7] = data[7] + rhs.data[7]; + + return result; + } + + /// Elementwise add operator (4-by-2) + CUTLASS_HOST_DEVICE + Matrix operator +(Matrix const &rhs) const { + return add(rhs); + } + + /// Elementwise add operator (4-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator +=(Matrix const &rhs) { + + data[0] += rhs.data[0]; + data[1] += rhs.data[1]; + + data[2] += rhs.data[2]; + data[3] += rhs.data[3]; + + data[4] += rhs.data[4]; + data[5] += rhs.data[5]; + + data[6] += rhs.data[6]; + data[7] += rhs.data[7]; + + return *this; + } + + /// Elementwise subtract operator (4-by-2) + CUTLASS_HOST_DEVICE + Matrix subtract(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] - rhs.data[0]; + result.data[1] = data[1] - rhs.data[1]; + + result.data[2] = data[2] - rhs.data[2]; + result.data[3] = data[3] - rhs.data[3]; + + result.data[4] = data[4] - rhs.data[4]; + result.data[5] = data[5] - rhs.data[5]; + + result.data[6] = data[6] - rhs.data[6]; + result.data[7] = data[7] - rhs.data[7]; + + return result; + } + + /// Elementwise subtract operator (4-by-2) + CUTLASS_HOST_DEVICE + Matrix operator -(Matrix const &rhs) const { + return subtract(rhs); + } + + /// Elementwise subtract operator (4-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator -=(Matrix const &rhs) { + + data[0] -= rhs.data[0]; + data[1] -= rhs.data[1]; + + data[2] -= rhs.data[2]; + data[3] -= rhs.data[3]; + + data[4] -= rhs.data[4]; + data[5] -= rhs.data[5]; + + data[6] -= rhs.data[6]; + data[7] -= rhs.data[7]; + + return *this; + } + + /// Elementwise multiply operator (4-by-2) + CUTLASS_HOST_DEVICE + Matrix multiply(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] * rhs.data[0]; + result.data[1] = data[1] * rhs.data[1]; + + result.data[2] = data[2] * rhs.data[2]; + result.data[3] = data[3] * rhs.data[3]; + + result.data[4] = data[4] * rhs.data[4]; + 
result.data[5] = data[5] * rhs.data[5]; + + result.data[6] = data[6] * rhs.data[6]; + result.data[7] = data[7] * rhs.data[7]; + + return result; + } + + /// Scalar multiply operator (4-by-2) + CUTLASS_HOST_DEVICE + Matrix multiply(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] * s; + result.data[1] = data[1] * s; + + result.data[2] = data[2] * s; + result.data[3] = data[3] * s; + + result.data[4] = data[4] * s; + result.data[5] = data[5] * s; + + result.data[6] = data[6] * s; + result.data[7] = data[7] * s; + + return result; + } + + /// Scalar multiply operator (4-by-2) + CUTLASS_HOST_DEVICE + Matrix operator *(Element const &s) const { + return multiply(s); + } + + /// Scalar multiply operator (4-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator *=(Element const &s) { + + data[0] *= s; + data[1] *= s; + + data[2] *= s; + data[3] *= s; + + data[4] *= s; + data[5] *= s; + + data[6] *= s; + data[7] *= s; + + return *this; + } + + /// Elementwise divide operator (4-by-2) + CUTLASS_HOST_DEVICE + Matrix divide(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] / rhs.data[0]; + result.data[1] = data[1] / rhs.data[1]; + + result.data[2] = data[2] / rhs.data[2]; + result.data[3] = data[3] / rhs.data[3]; + + result.data[4] = data[4] / rhs.data[4]; + result.data[5] = data[5] / rhs.data[5]; + + result.data[6] = data[6] / rhs.data[6]; + result.data[7] = data[7] / rhs.data[7]; + + return result; + } + + /// Scalar divide operator (4-by-2) + CUTLASS_HOST_DEVICE + Matrix divide(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] / s; + result.data[1] = data[1] / s; + + result.data[2] = data[2] / s; + result.data[3] = data[3] / s; + + result.data[4] = data[4] / s; + result.data[5] = data[5] / s; + + result.data[6] = data[6] / s; + result.data[7] = data[7] / s; + + return result; + } + + /// Scalar divide operator (4-by-2) + CUTLASS_HOST_DEVICE + Matrix operator /(Element const &s) const { + return divide(s); + } + + /// Scalar divide operator (4-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Element const &s) { + + data[0] /= s; + data[1] /= s; + + data[2] /= s; + data[3] /= s; + + data[4] /= s; + data[5] /= s; + + data[6] /= s; + data[7] /= s; + + return *this; + } + + /// Elementwise divide operator (4-by-2) + CUTLASS_HOST_DEVICE + Matrix operator /(Matrix const &rhs) const { + return divide(rhs); + } + + /// Elementwise divide operator (4-by-2) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Matrix const &rhs) { + + data[0] /= rhs.data[0]; + data[1] /= rhs.data[1]; + + data[2] /= rhs.data[2]; + data[3] /= rhs.data[3]; + + data[4] /= rhs.data[4]; + data[5] /= rhs.data[5]; + + data[6] /= rhs.data[6]; + data[7] /= rhs.data[7]; + + return *this; + } + + /// Negates each element of the matrix + CUTLASS_HOST_DEVICE + Matrix operator-() const { + Matrix m; + + m.data[0] = -data[0]; + m.data[1] = -data[1]; + m.data[2] = -data[2]; + m.data[3] = -data[3]; + m.data[4] = -data[4]; + m.data[5] = -data[5]; + m.data[6] = -data[6]; + m.data[7] = -data[7]; + + return m; + } + + /// Matrix product of size 4-by-1-by-2 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[2] * rhs.data[0]; + accum.data[2] += data[4] * rhs.data[0]; + accum.data[3] += data[6] * rhs.data[0]; + + // k=1 + accum.data[0] += data[1] * rhs.data[1]; + accum.data[1] += data[3] * rhs.data[1]; + accum.data[2] += data[5] * rhs.data[1]; + 
accum.data[3] += data[7] * rhs.data[1]; + + return accum; + } + + /// Matrix product of size 4-by-1-by-2 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 4-by-2-by-2 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[2] * rhs.data[0]; + accum.data[3] += data[2] * rhs.data[1]; + accum.data[4] += data[4] * rhs.data[0]; + accum.data[5] += data[4] * rhs.data[1]; + accum.data[6] += data[6] * rhs.data[0]; + accum.data[7] += data[6] * rhs.data[1]; + + // k=1 + accum.data[0] += data[1] * rhs.data[2]; + accum.data[1] += data[1] * rhs.data[3]; + accum.data[2] += data[3] * rhs.data[2]; + accum.data[3] += data[3] * rhs.data[3]; + accum.data[4] += data[5] * rhs.data[2]; + accum.data[5] += data[5] * rhs.data[3]; + accum.data[6] += data[7] * rhs.data[2]; + accum.data[7] += data[7] * rhs.data[3]; + + return accum; + } + + /// Matrix product of size 4-by-2-by-2 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 4-by-2-by-2 + CUTLASS_HOST_DEVICE + Matrix & operator*=(Matrix const &rhs) { + *this = product(rhs); + return *this; + } + + /// Matrix product of size 4-by-3-by-2 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[2] * rhs.data[0]; + accum.data[4] += data[2] * rhs.data[1]; + accum.data[5] += data[2] * rhs.data[2]; + accum.data[6] += data[4] * rhs.data[0]; + accum.data[7] += data[4] * rhs.data[1]; + accum.data[8] += data[4] * rhs.data[2]; + accum.data[9] += data[6] * rhs.data[0]; + accum.data[10] += data[6] * rhs.data[1]; + accum.data[11] += data[6] * rhs.data[2]; + + // k=1 + accum.data[0] += data[1] * rhs.data[3]; + accum.data[1] += data[1] * rhs.data[4]; + accum.data[2] += data[1] * rhs.data[5]; + accum.data[3] += data[3] * rhs.data[3]; + accum.data[4] += data[3] * rhs.data[4]; + accum.data[5] += data[3] * rhs.data[5]; + accum.data[6] += data[5] * rhs.data[3]; + accum.data[7] += data[5] * rhs.data[4]; + accum.data[8] += data[5] * rhs.data[5]; + accum.data[9] += data[7] * rhs.data[3]; + accum.data[10] += data[7] * rhs.data[4]; + accum.data[11] += data[7] * rhs.data[5]; + + return accum; + } + + /// Matrix product of size 4-by-3-by-2 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 4-by-4-by-2 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[0] * rhs.data[3]; + accum.data[4] += data[2] * rhs.data[0]; + accum.data[5] += data[2] * rhs.data[1]; + accum.data[6] += data[2] * rhs.data[2]; + accum.data[7] += data[2] * rhs.data[3]; + accum.data[8] += data[4] * rhs.data[0]; + accum.data[9] += data[4] * rhs.data[1]; + accum.data[10] += data[4] * rhs.data[2]; + accum.data[11] += data[4] * rhs.data[3]; + accum.data[12] += data[6] * rhs.data[0]; + accum.data[13] += data[6] * rhs.data[1]; + accum.data[14] += data[6] * rhs.data[2]; + accum.data[15] += data[6] * rhs.data[3]; + + // k=1 + accum.data[0] += data[1] * 
rhs.data[4]; + accum.data[1] += data[1] * rhs.data[5]; + accum.data[2] += data[1] * rhs.data[6]; + accum.data[3] += data[1] * rhs.data[7]; + accum.data[4] += data[3] * rhs.data[4]; + accum.data[5] += data[3] * rhs.data[5]; + accum.data[6] += data[3] * rhs.data[6]; + accum.data[7] += data[3] * rhs.data[7]; + accum.data[8] += data[5] * rhs.data[4]; + accum.data[9] += data[5] * rhs.data[5]; + accum.data[10] += data[5] * rhs.data[6]; + accum.data[11] += data[5] * rhs.data[7]; + accum.data[12] += data[7] * rhs.data[4]; + accum.data[13] += data[7] * rhs.data[5]; + accum.data[14] += data[7] * rhs.data[6]; + accum.data[15] += data[7] * rhs.data[7]; + + return accum; + } + + /// Matrix product of size 4-by-4-by-2 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Returns the sum of elements + CUTLASS_HOST_DEVICE + Element sum(Element accum = Element()) const { + + accum += data[0]; + accum += data[1]; + accum += data[2]; + accum += data[3]; + accum += data[4]; + accum += data[5]; + accum += data[6]; + accum += data[7]; + + return accum; + } + + /// Returns the sum of squared elements + CUTLASS_HOST_DEVICE + Element norm(Element accum = Element()) const { + + accum += data[0] * data[0]; + accum += data[1] * data[1]; + accum += data[2] * data[2]; + accum += data[3] * data[3]; + accum += data[4] * data[4]; + accum += data[5] * data[5]; + accum += data[6] * data[6]; + accum += data[7] * data[7]; + + return accum; + } + + /// Returns square root of the norm + CUTLASS_HOST_DEVICE + Element magnitude() const { + return fast_sqrt(norm()); + } + + /// Returns the sum of diagonal elements + CUTLASS_HOST_DEVICE + Element trace(Element accum = Element()) const { + + accum += data[0]; + accum += data[3]; + + return accum; + } + +}; + +/// Template alias for 4-by-2 matrix +template +using Matrix4x2 = Matrix; + + +/// Free funciton to infer element type from template arguments +template +CUTLASS_HOST_DEVICE Matrix4x2 make_Matrix4x2( + Element _0_0, Element _0_1, + Element _1_0, Element _1_1, + Element _2_0, Element _2_1, + Element _3_0, Element _3_1 +) { + return Matrix4x2( + _0_0, _0_1, + _1_0, _1_1, + _2_0, _2_1, + _3_0, _3_1 + ); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 4-by-3 matrix template class definition +template +struct Matrix { + + // + // Type definitions + // + + /// Element data type + using Element = Element_; + + /// Number of rows in matrix + static int const kRows = 4; + + /// Number of columns in matrix + static int const kColumns = 3; + + /// Layout of matrix in underlying array + using Layout = layout::RowMajor; + + /// Number of elements in matrix + static int const kCount = 12; + + // + // Data members + // + + /// Elements of the matrix in row-major layout + Array data; + + // + // Methods + // + + /// Constructs a zero matrix + CUTLASS_HOST_DEVICE + Matrix() { + data.clear(); + } + + /// Copy constructor for a 4-by-3 matrix + CUTLASS_HOST_DEVICE + Matrix(Matrix const &rhs) { + data = rhs.data; + } + + /// Constucts a 4-by-3 matrix from scalar elements + CUTLASS_HOST_DEVICE + Matrix( + Element _0_0, Element _0_1, Element _0_2, + Element _1_0, Element _1_1, Element _1_2, + Element _2_0, Element _2_1, Element _2_2, + Element _3_0, Element _3_1, Element _3_2 + ) { + + data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; + data[3] = _1_0; data[4] = _1_1; data[5] = _1_2; + data[6] = _2_0; data[7] = _2_1; data[8] = _2_2; + data[9] = _3_0; data[10] = _3_1; data[11] = _3_2; + } + 
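The product(rhs, accum) methods above implement a small GEMM-like update, accum += (*this) * rhs, with operator* as the zero-accumulator shorthand. The sketch below is illustrative only and not part of the patch; it assumes the cutlass namespace, the Matrix<Element, Rows, Columns> template, and that the 2-by-2 specialization elsewhere in this header offers the same generated ones()/uniform() helpers.

#include "cutlass/matrix.h"

// Small GEMM-like update using the 4-by-2 class shown above.
float example_small_gemm() {
  cutlass::Matrix<float, 4, 2> a = cutlass::Matrix<float, 4, 2>::uniform(1.0f);
  cutlass::Matrix<float, 2, 2> b = cutlass::Matrix<float, 2, 2>::ones();
  cutlass::Matrix<float, 4, 2> c = cutlass::Matrix<float, 4, 2>::zero();

  cutlass::Matrix<float, 4, 2> d = a.product(b, c);   // d = a * b + c
  cutlass::Matrix<float, 4, 2> e = a * b + c;         // same result via operators

  return (d - e).norm();                              // expected to be zero
}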
+ /// Constucts a 4-by-3 matrix from row vectors + CUTLASS_HOST_DEVICE + Matrix( + Matrix const &row_0, + Matrix const &row_1, + Matrix const &row_2, + Matrix const &row_3 + ) { + data[0] = row_0.data[0]; + data[1] = row_0.data[1]; + data[2] = row_0.data[2]; + data[3] = row_1.data[0]; + data[4] = row_1.data[1]; + data[5] = row_1.data[2]; + data[6] = row_2.data[0]; + data[7] = row_2.data[1]; + data[8] = row_2.data[2]; + data[9] = row_3.data[0]; + data[10] = row_3.data[1]; + data[11] = row_3.data[2]; + } + + /// Static method to construct a 4-by-3 matrix from column vectors + CUTLASS_HOST_DEVICE + static Matrix from_columns( + Matrix const &column_0, + Matrix const &column_1, + Matrix const &column_2 + ) { + Matrix result; + + result.data[0] = column_0.data[0]; + result.data[1] = column_1.data[0]; + result.data[2] = column_2.data[0]; + result.data[3] = column_0.data[1]; + result.data[4] = column_1.data[1]; + result.data[5] = column_2.data[1]; + result.data[6] = column_0.data[2]; + result.data[7] = column_1.data[2]; + result.data[8] = column_2.data[2]; + result.data[9] = column_0.data[3]; + result.data[10] = column_1.data[3]; + result.data[11] = column_2.data[3]; + return result; + } + + /// Constructs a matrix from a uniform element + CUTLASS_HOST_DEVICE + static Matrix uniform(Element s) { + Matrix m; + + m.data[0] = s; + m.data[1] = s; + m.data[2] = s; + m.data[3] = s; + m.data[4] = s; + m.data[5] = s; + m.data[6] = s; + m.data[7] = s; + m.data[8] = s; + m.data[9] = s; + m.data[10] = s; + m.data[11] = s; + + return m; + } + + /// Constructs a matrix from a uniform element 1 + CUTLASS_HOST_DEVICE + static Matrix ones() { + return uniform(Element(1)); + } + + /// Constructs a matrix from a uniform element 0 + CUTLASS_HOST_DEVICE + static Matrix zero() { + return Matrix(); + } + + /// Constructs a matrix from elements along its diagonal + CUTLASS_HOST_DEVICE + static Matrix from_diagonal(Matrix const &diag) { + Matrix m; + + m.data[0] = diag.data[0]; + m.data[5] = diag.data[1]; + m.data[10] = diag.data[2]; + m.data[15] = diag.data[3]; + + return m; + } + + /// Constructs a matrix from elements along its diagonal + CUTLASS_HOST_DEVICE + static Matrix from_diagonal(Matrix const &diag) { + Matrix m; + + m.data[0] = diag.data[0]; + m.data[5] = diag.data[1]; + m.data[10] = diag.data[2]; + m.data[15] = diag.data[3]; + + return m; + } + + /// Gets an array of diagonal elements + CUTLASS_HOST_DEVICE + Matrix diagonal() const { + Matrix diag; + + diag.data[0] = data[0]; + diag.data[1] = data[5]; + diag.data[2] = data[10]; + diag.data[3] = data[15]; + + return diag; + } + + /// Returns a transposed matrix + CUTLASS_HOST_DEVICE + Matrix transpose() const { + Matrix mt; + + mt.data[0] = data[0]; + mt.data[4] = data[1]; + mt.data[8] = data[2]; + mt.data[1] = data[3]; + mt.data[5] = data[4]; + mt.data[9] = data[5]; + mt.data[2] = data[6]; + mt.data[6] = data[7]; + mt.data[10] = data[8]; + mt.data[3] = data[9]; + mt.data[7] = data[10]; + mt.data[11] = data[11]; + + return mt; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(int i, int j) const { + return data[i * 4 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(int i, int j) { + return data[i * 4 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(Coord<2> const &coord) { + return at(coord[0], coord[1]); + 
} + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element &at(int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element at(int offset) const { + return data[offset]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element operator[](Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & operator[](Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element & operator[](int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element operator[](int offset) const { + return data[offset]; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + m.data[2] = data[i * 3 + j + 2]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + data[i * 3 + j + 2] = m.data[2]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix row(int i) const { + return slice_1x3(i, 0); + } + + Matrix &set_row(Matrix const &v, int i = 0) { + return set_slice_1x3(v, i, 0); + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 3]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 3] = m.data[1]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + m.data[2] = data[i * 3 + j + 3]; + m.data[3] = data[i * 3 + j + 4]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + data[i * 3 + j + 3] = m.data[2]; + data[i * 3 + j + 4] = m.data[3]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + m.data[2] = data[i * 3 + j + 2]; + m.data[3] = data[i * 3 + j + 3]; + m.data[4] = data[i * 3 + j + 4]; + m.data[5] = data[i * 3 + j + 5]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + 
data[i * 3 + j + 2] = m.data[2]; + data[i * 3 + j + 3] = m.data[3]; + data[i * 3 + j + 4] = m.data[4]; + data[i * 3 + j + 5] = m.data[5]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 3]; + m.data[2] = data[i * 3 + j + 6]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 3] = m.data[1]; + data[i * 3 + j + 6] = m.data[2]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + m.data[2] = data[i * 3 + j + 3]; + m.data[3] = data[i * 3 + j + 4]; + m.data[4] = data[i * 3 + j + 6]; + m.data[5] = data[i * 3 + j + 7]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + data[i * 3 + j + 3] = m.data[2]; + data[i * 3 + j + 4] = m.data[3]; + data[i * 3 + j + 6] = m.data[4]; + data[i * 3 + j + 7] = m.data[5]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + m.data[2] = data[i * 3 + j + 2]; + m.data[3] = data[i * 3 + j + 3]; + m.data[4] = data[i * 3 + j + 4]; + m.data[5] = data[i * 3 + j + 5]; + m.data[6] = data[i * 3 + j + 6]; + m.data[7] = data[i * 3 + j + 7]; + m.data[8] = data[i * 3 + j + 8]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + data[i * 3 + j + 2] = m.data[2]; + data[i * 3 + j + 3] = m.data[3]; + data[i * 3 + j + 4] = m.data[4]; + data[i * 3 + j + 5] = m.data[5]; + data[i * 3 + j + 6] = m.data[6]; + data[i * 3 + j + 7] = m.data[7]; + data[i * 3 + j + 8] = m.data[8]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_4x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 3]; + m.data[2] = data[i * 3 + j + 6]; + m.data[3] = data[i * 3 + j + 9]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_4x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 3] = m.data[1]; + data[i * 3 + j + 6] = m.data[2]; + data[i * 3 + j + 9] = m.data[3]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix column(int j) const { + return slice_4x1(0, j); + } + + Matrix &set_column(Matrix const &v, int j =0) { + return set_slice_4x1(v, 0, j); + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_4x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + m.data[2] = data[i * 3 + j + 3]; + m.data[3] = data[i * 3 + j + 4]; + m.data[4] = data[i * 3 + j + 6]; + m.data[5] = data[i * 3 + j + 7]; + m.data[6] = data[i * 3 + j + 9]; + m.data[7] = data[i * 3 + j + 10]; + + return m; + 
} + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_4x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + data[i * 3 + j + 3] = m.data[2]; + data[i * 3 + j + 4] = m.data[3]; + data[i * 3 + j + 6] = m.data[4]; + data[i * 3 + j + 7] = m.data[5]; + data[i * 3 + j + 9] = m.data[6]; + data[i * 3 + j + 10] = m.data[7]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_4x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 3 + j + 0]; + m.data[1] = data[i * 3 + j + 1]; + m.data[2] = data[i * 3 + j + 2]; + m.data[3] = data[i * 3 + j + 3]; + m.data[4] = data[i * 3 + j + 4]; + m.data[5] = data[i * 3 + j + 5]; + m.data[6] = data[i * 3 + j + 6]; + m.data[7] = data[i * 3 + j + 7]; + m.data[8] = data[i * 3 + j + 8]; + m.data[9] = data[i * 3 + j + 9]; + m.data[10] = data[i * 3 + j + 10]; + m.data[11] = data[i * 3 + j + 11]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_4x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 3 + j + 0] = m.data[0]; + data[i * 3 + j + 1] = m.data[1]; + data[i * 3 + j + 2] = m.data[2]; + data[i * 3 + j + 3] = m.data[3]; + data[i * 3 + j + 4] = m.data[4]; + data[i * 3 + j + 5] = m.data[5]; + data[i * 3 + j + 6] = m.data[6]; + data[i * 3 + j + 7] = m.data[7]; + data[i * 3 + j + 8] = m.data[8]; + data[i * 3 + j + 9] = m.data[9]; + data[i * 3 + j + 10] = m.data[10]; + data[i * 3 + j + 11] = m.data[11]; + + return *this; + } + + /// Forms a 4-by-3 matrix by horizontally concatenating a 4-by-1 matrix with a 4-by-2 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1) + , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1) + , lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1) + , lhs.at(3, 0), rhs.at(3, 0), rhs.at(3, 1)); + } + + /// Forms a 4-by-3 matrix by horizontally concatenating a 4-by-2 matrix with a 4-by-1 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0) + , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0) + , lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0) + , lhs.at(3, 0), lhs.at(3, 1), rhs.at(3, 0)); + } + + /// Concatenates this matrix with a a 4-by-1 matrix to form a 4-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix hcat(Matrix const & rhs) const { + return Matrix::hcat(*this, rhs); + } + + /// Forms a 4-by-3 matrix by vertically concatenating a 1-by-3 matrix with a 3-by-3 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Matrix const & lower) { + return Matrix( + upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) + , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2) + , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2) + , lower.at(2, 0), lower.at(2, 1), lower.at(2, 2)); + } + + /// Forms a 4-by-3 matrix by vertically concatenating a 2-by-3 matrix with a 2-by-3 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Matrix const & lower) { + return Matrix( + upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) + , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2) + , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2) + , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2)); + } + + /// Forms a 4-by-3 matrix by vertically concatenating a 3-by-3 matrix with a 1-by-3 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Matrix const & 
lower) { + return Matrix( + upper.at(0, 0), upper.at(0, 1), upper.at(0, 2) + , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2) + , upper.at(2, 0), upper.at(2, 1), upper.at(2, 2) + , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2)); + } + + /// Forms a 4-by-3 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Element A, Matrix const & B, + Matrix const & C, Matrix const & D) { + return Matrix( + A, B.at(0, 0), B.at(0, 1) + , C.at(0, 0), D.at(0, 0), D.at(0, 1) + , C.at(1, 0), D.at(1, 0), D.at(1, 1) + , C.at(2, 0), D.at(2, 0), D.at(2, 1) + ); + } + + /// Forms a 4-by-3 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Element B, + Matrix const & C, Matrix const & D) { + return Matrix( + A.at(0, 0), A.at(0, 1), B + , C.at(0, 0), C.at(0, 1), D.at(0, 0) + , C.at(1, 0), C.at(1, 1), D.at(1, 0) + , C.at(2, 0), C.at(2, 1), D.at(2, 0) + ); + } + + /// Forms a 4-by-3 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Matrix const & C, Matrix const & D) { + return Matrix( + A.at(0, 0), B.at(0, 0), B.at(0, 1) + , A.at(1, 0), B.at(1, 0), B.at(1, 1) + , C.at(0, 0), D.at(0, 0), D.at(0, 1) + , C.at(1, 0), D.at(1, 0), D.at(1, 1) + ); + } + + /// Forms a 4-by-3 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Matrix const & C, Matrix const & D) { + return Matrix( + A.at(0, 0), A.at(0, 1), B.at(0, 0) + , A.at(1, 0), A.at(1, 1), B.at(1, 0) + , C.at(0, 0), C.at(0, 1), D.at(0, 0) + , C.at(1, 0), C.at(1, 1), D.at(1, 0) + ); + } + + /// Forms a 4-by-3 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Element C, Matrix const & D) { + return Matrix( + A.at(0, 0), B.at(0, 0), B.at(0, 1) + , A.at(1, 0), B.at(1, 0), B.at(1, 1) + , A.at(2, 0), B.at(2, 0), B.at(2, 1) + , C, D.at(0, 0), D.at(0, 1) + ); + } + + /// Forms a 4-by-3 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Matrix const & C, Element D) { + return Matrix( + A.at(0, 0), A.at(0, 1), B.at(0, 0) + , A.at(1, 0), A.at(1, 1), B.at(1, 0) + , A.at(2, 0), A.at(2, 1), B.at(2, 0) + , C.at(0, 0), C.at(0, 1), D + ); + } + + /// Elementwise add operator (4-by-3) + CUTLASS_HOST_DEVICE + Matrix add(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] + rhs.data[0]; + result.data[1] = data[1] + rhs.data[1]; + result.data[2] = data[2] + rhs.data[2]; + + result.data[3] = data[3] + rhs.data[3]; + result.data[4] = data[4] + rhs.data[4]; + result.data[5] = data[5] + rhs.data[5]; + + result.data[6] = data[6] + rhs.data[6]; + result.data[7] = data[7] + rhs.data[7]; + result.data[8] = data[8] + rhs.data[8]; + + result.data[9] = data[9] + rhs.data[9]; + result.data[10] = data[10] + rhs.data[10]; + result.data[11] = data[11] + rhs.data[11]; + + return result; + } + + /// Elementwise add operator (4-by-3) + CUTLASS_HOST_DEVICE + Matrix operator +(Matrix const &rhs) const { + return add(rhs); + } + + /// Elementwise add operator (4-by-3) + CUTLASS_HOST_DEVICE + Matrix & operator +=(Matrix const &rhs) { + + data[0] += rhs.data[0]; + data[1] += rhs.data[1]; + data[2] += rhs.data[2]; + + data[3] += rhs.data[3]; + data[4] += rhs.data[4]; + data[5] += rhs.data[5]; + + data[6] += rhs.data[6]; + data[7] += rhs.data[7]; + data[8] += rhs.data[8]; + + data[9] += rhs.data[9]; + data[10] += 
rhs.data[10]; + data[11] += rhs.data[11]; + + return *this; + } + + /// Elementwise subtract operator (4-by-3) + CUTLASS_HOST_DEVICE + Matrix subtract(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] - rhs.data[0]; + result.data[1] = data[1] - rhs.data[1]; + result.data[2] = data[2] - rhs.data[2]; + + result.data[3] = data[3] - rhs.data[3]; + result.data[4] = data[4] - rhs.data[4]; + result.data[5] = data[5] - rhs.data[5]; + + result.data[6] = data[6] - rhs.data[6]; + result.data[7] = data[7] - rhs.data[7]; + result.data[8] = data[8] - rhs.data[8]; + + result.data[9] = data[9] - rhs.data[9]; + result.data[10] = data[10] - rhs.data[10]; + result.data[11] = data[11] - rhs.data[11]; + + return result; + } + + /// Elementwise subtract operator (4-by-3) + CUTLASS_HOST_DEVICE + Matrix operator -(Matrix const &rhs) const { + return subtract(rhs); + } + + /// Elementwise subtract operator (4-by-3) + CUTLASS_HOST_DEVICE + Matrix & operator -=(Matrix const &rhs) { + + data[0] -= rhs.data[0]; + data[1] -= rhs.data[1]; + data[2] -= rhs.data[2]; + + data[3] -= rhs.data[3]; + data[4] -= rhs.data[4]; + data[5] -= rhs.data[5]; + + data[6] -= rhs.data[6]; + data[7] -= rhs.data[7]; + data[8] -= rhs.data[8]; + + data[9] -= rhs.data[9]; + data[10] -= rhs.data[10]; + data[11] -= rhs.data[11]; + + return *this; + } + + /// Elementwise multiply operator (4-by-3) + CUTLASS_HOST_DEVICE + Matrix multiply(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] * rhs.data[0]; + result.data[1] = data[1] * rhs.data[1]; + result.data[2] = data[2] * rhs.data[2]; + + result.data[3] = data[3] * rhs.data[3]; + result.data[4] = data[4] * rhs.data[4]; + result.data[5] = data[5] * rhs.data[5]; + + result.data[6] = data[6] * rhs.data[6]; + result.data[7] = data[7] * rhs.data[7]; + result.data[8] = data[8] * rhs.data[8]; + + result.data[9] = data[9] * rhs.data[9]; + result.data[10] = data[10] * rhs.data[10]; + result.data[11] = data[11] * rhs.data[11]; + + return result; + } + + /// Scalar multiply operator (4-by-3) + CUTLASS_HOST_DEVICE + Matrix multiply(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] * s; + result.data[1] = data[1] * s; + result.data[2] = data[2] * s; + + result.data[3] = data[3] * s; + result.data[4] = data[4] * s; + result.data[5] = data[5] * s; + + result.data[6] = data[6] * s; + result.data[7] = data[7] * s; + result.data[8] = data[8] * s; + + result.data[9] = data[9] * s; + result.data[10] = data[10] * s; + result.data[11] = data[11] * s; + + return result; + } + + /// Scalar multiply operator (4-by-3) + CUTLASS_HOST_DEVICE + Matrix operator *(Element const &s) const { + return multiply(s); + } + + /// Scalar multiply operator (4-by-3) + CUTLASS_HOST_DEVICE + Matrix & operator *=(Element const &s) { + + data[0] *= s; + data[1] *= s; + data[2] *= s; + + data[3] *= s; + data[4] *= s; + data[5] *= s; + + data[6] *= s; + data[7] *= s; + data[8] *= s; + + data[9] *= s; + data[10] *= s; + data[11] *= s; + + return *this; + } + + /// Elementwise divide operator (4-by-3) + CUTLASS_HOST_DEVICE + Matrix divide(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] / rhs.data[0]; + result.data[1] = data[1] / rhs.data[1]; + result.data[2] = data[2] / rhs.data[2]; + + result.data[3] = data[3] / rhs.data[3]; + result.data[4] = data[4] / rhs.data[4]; + result.data[5] = data[5] / rhs.data[5]; + + result.data[6] = data[6] / rhs.data[6]; + result.data[7] = data[7] / rhs.data[7]; + result.data[8] = data[8] / rhs.data[8]; + + 
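+ // Note: multiply(Matrix) and divide(Matrix) operate elementwise (Hadamard-style) on
+ // corresponding entries, unlike operator*, which is a true matrix product. A minimal
+ // sketch for two hypothetical 4-by-3 matrices A and B, assuming Element = float and the
+ // Matrix4x3 alias defined later in this header:
+ //
+ //   Matrix4x3<float> H = A.divide(B);   // H(i, j) = A(i, j) / B(i, j)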
result.data[9] = data[9] / rhs.data[9];
+ result.data[10] = data[10] / rhs.data[10];
+ result.data[11] = data[11] / rhs.data[11];
+
+ return result;
+ }
+
+ /// Scalar divide operator (4-by-3)
+ CUTLASS_HOST_DEVICE
+ Matrix divide(Element const &s) const {
+
+ Matrix result;
+
+ result.data[0] = data[0] / s;
+ result.data[1] = data[1] / s;
+ result.data[2] = data[2] / s;
+
+ result.data[3] = data[3] / s;
+ result.data[4] = data[4] / s;
+ result.data[5] = data[5] / s;
+
+ result.data[6] = data[6] / s;
+ result.data[7] = data[7] / s;
+ result.data[8] = data[8] / s;
+
+ result.data[9] = data[9] / s;
+ result.data[10] = data[10] / s;
+ result.data[11] = data[11] / s;
+
+ return result;
+ }
+
+ /// Scalar divide operator (4-by-3)
+ CUTLASS_HOST_DEVICE
+ Matrix operator /(Element const &s) const {
+ return divide(s);
+ }
+
+ /// Scalar divide operator (4-by-3)
+ CUTLASS_HOST_DEVICE
+ Matrix & operator /=(Element const &s) {
+
+ data[0] /= s;
+ data[1] /= s;
+ data[2] /= s;
+
+ data[3] /= s;
+ data[4] /= s;
+ data[5] /= s;
+
+ data[6] /= s;
+ data[7] /= s;
+ data[8] /= s;
+
+ data[9] /= s;
+ data[10] /= s;
+ data[11] /= s;
+
+ return *this;
+ }
+
+ /// Elementwise divide operator (4-by-3)
+ CUTLASS_HOST_DEVICE
+ Matrix operator /(Matrix const &rhs) const {
+ return divide(rhs);
+ }
+
+ /// Elementwise divide operator (4-by-3)
+ CUTLASS_HOST_DEVICE
+ Matrix & operator /=(Matrix const &rhs) {
+
+ data[0] /= rhs.data[0];
+ data[1] /= rhs.data[1];
+ data[2] /= rhs.data[2];
+
+ data[3] /= rhs.data[3];
+ data[4] /= rhs.data[4];
+ data[5] /= rhs.data[5];
+
+ data[6] /= rhs.data[6];
+ data[7] /= rhs.data[7];
+ data[8] /= rhs.data[8];
+
+ data[9] /= rhs.data[9];
+ data[10] /= rhs.data[10];
+ data[11] /= rhs.data[11];
+
+ return *this;
+ }
+
+ /// Negates each element of the matrix
+ CUTLASS_HOST_DEVICE
+ Matrix operator-() const {
+ Matrix m;
+
+ m.data[0] = -data[0];
+ m.data[1] = -data[1];
+ m.data[2] = -data[2];
+ m.data[3] = -data[3];
+ m.data[4] = -data[4];
+ m.data[5] = -data[5];
+ m.data[6] = -data[6];
+ m.data[7] = -data[7];
+ m.data[8] = -data[8];
+ m.data[9] = -data[9];
+ m.data[10] = -data[10];
+ m.data[11] = -data[11];
+
+ return m;
+ }
+
+ /// Matrix product of size 4-by-1-by-3
+ CUTLASS_HOST_DEVICE
+ Matrix product(
+ Matrix const &rhs,
+ Matrix accum = Matrix()
+ ) const {
+
+ // k=0
+ accum.data[0] += data[0] * rhs.data[0];
+ accum.data[1] += data[3] * rhs.data[0];
+ accum.data[2] += data[6] * rhs.data[0];
+ accum.data[3] += data[9] * rhs.data[0];
+
+ // k=1
+ accum.data[0] += data[1] * rhs.data[1];
+ accum.data[1] += data[4] * rhs.data[1];
+ accum.data[2] += data[7] * rhs.data[1];
+ accum.data[3] += data[10] * rhs.data[1];
+
+ // k=2
+ accum.data[0] += data[2] * rhs.data[2];
+ accum.data[1] += data[5] * rhs.data[2];
+ accum.data[2] += data[8] * rhs.data[2];
+ accum.data[3] += data[11] * rhs.data[2];
+
+ return accum;
+ }
+
+ /// Matrix product of size 4-by-1-by-3
+ CUTLASS_HOST_DEVICE
+ Matrix operator*(Matrix const &rhs) const {
+ return product(rhs);
+ }
+
+ /// Matrix product of size 4-by-2-by-3
+ CUTLASS_HOST_DEVICE
+ Matrix product(
+ Matrix const &rhs,
+ Matrix accum = Matrix()
+ ) const {
+
+ // k=0
+ accum.data[0] += data[0] * rhs.data[0];
+ accum.data[1] += data[0] * rhs.data[1];
+ accum.data[2] += data[3] * rhs.data[0];
+ accum.data[3] += data[3] * rhs.data[1];
+ accum.data[4] += data[6] * rhs.data[0];
+ accum.data[5] += data[6] * rhs.data[1];
+ accum.data[6] += data[9] * rhs.data[0];
+ accum.data[7] += data[9] * rhs.data[1];
+
+ // k=1
+ accum.data[0] += 
data[1] * rhs.data[2]; + accum.data[1] += data[1] * rhs.data[3]; + accum.data[2] += data[4] * rhs.data[2]; + accum.data[3] += data[4] * rhs.data[3]; + accum.data[4] += data[7] * rhs.data[2]; + accum.data[5] += data[7] * rhs.data[3]; + accum.data[6] += data[10] * rhs.data[2]; + accum.data[7] += data[10] * rhs.data[3]; + + // k=2 + accum.data[0] += data[2] * rhs.data[4]; + accum.data[1] += data[2] * rhs.data[5]; + accum.data[2] += data[5] * rhs.data[4]; + accum.data[3] += data[5] * rhs.data[5]; + accum.data[4] += data[8] * rhs.data[4]; + accum.data[5] += data[8] * rhs.data[5]; + accum.data[6] += data[11] * rhs.data[4]; + accum.data[7] += data[11] * rhs.data[5]; + + return accum; + } + + /// Matrix product of size 4-by-2-by-3 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 4-by-3-by-3 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[3] * rhs.data[0]; + accum.data[4] += data[3] * rhs.data[1]; + accum.data[5] += data[3] * rhs.data[2]; + accum.data[6] += data[6] * rhs.data[0]; + accum.data[7] += data[6] * rhs.data[1]; + accum.data[8] += data[6] * rhs.data[2]; + accum.data[9] += data[9] * rhs.data[0]; + accum.data[10] += data[9] * rhs.data[1]; + accum.data[11] += data[9] * rhs.data[2]; + + // k=1 + accum.data[0] += data[1] * rhs.data[3]; + accum.data[1] += data[1] * rhs.data[4]; + accum.data[2] += data[1] * rhs.data[5]; + accum.data[3] += data[4] * rhs.data[3]; + accum.data[4] += data[4] * rhs.data[4]; + accum.data[5] += data[4] * rhs.data[5]; + accum.data[6] += data[7] * rhs.data[3]; + accum.data[7] += data[7] * rhs.data[4]; + accum.data[8] += data[7] * rhs.data[5]; + accum.data[9] += data[10] * rhs.data[3]; + accum.data[10] += data[10] * rhs.data[4]; + accum.data[11] += data[10] * rhs.data[5]; + + // k=2 + accum.data[0] += data[2] * rhs.data[6]; + accum.data[1] += data[2] * rhs.data[7]; + accum.data[2] += data[2] * rhs.data[8]; + accum.data[3] += data[5] * rhs.data[6]; + accum.data[4] += data[5] * rhs.data[7]; + accum.data[5] += data[5] * rhs.data[8]; + accum.data[6] += data[8] * rhs.data[6]; + accum.data[7] += data[8] * rhs.data[7]; + accum.data[8] += data[8] * rhs.data[8]; + accum.data[9] += data[11] * rhs.data[6]; + accum.data[10] += data[11] * rhs.data[7]; + accum.data[11] += data[11] * rhs.data[8]; + + return accum; + } + + /// Matrix product of size 4-by-3-by-3 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 4-by-3-by-3 + CUTLASS_HOST_DEVICE + Matrix & operator*=(Matrix const &rhs) { + *this = product(rhs); + return *this; + } + + /// Matrix product of size 4-by-4-by-3 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[0] * rhs.data[3]; + accum.data[4] += data[3] * rhs.data[0]; + accum.data[5] += data[3] * rhs.data[1]; + accum.data[6] += data[3] * rhs.data[2]; + accum.data[7] += data[3] * rhs.data[3]; + accum.data[8] += data[6] * rhs.data[0]; + accum.data[9] += data[6] * rhs.data[1]; + accum.data[10] += data[6] * rhs.data[2]; + accum.data[11] += data[6] * rhs.data[3]; + accum.data[12] += data[9] * rhs.data[0]; + 
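+ // Note: each of the k=0..2 blocks in this method accumulates the outer product of
+ // column k of this 4-by-3 matrix with row k of rhs, and the result starts from accum,
+ // so a caller can fuse a multiply-add. A minimal sketch with hypothetical operands,
+ // assuming Element = float and the Matrix4x3, Matrix3x4, and Matrix4x4 aliases defined
+ // in this header:
+ //
+ //   Matrix4x4<float> D = A.product(B, C);   // A: 4x3, B: 3x4, C: 4x4; D = A * B + C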
accum.data[13] += data[9] * rhs.data[1]; + accum.data[14] += data[9] * rhs.data[2]; + accum.data[15] += data[9] * rhs.data[3]; + + // k=1 + accum.data[0] += data[1] * rhs.data[4]; + accum.data[1] += data[1] * rhs.data[5]; + accum.data[2] += data[1] * rhs.data[6]; + accum.data[3] += data[1] * rhs.data[7]; + accum.data[4] += data[4] * rhs.data[4]; + accum.data[5] += data[4] * rhs.data[5]; + accum.data[6] += data[4] * rhs.data[6]; + accum.data[7] += data[4] * rhs.data[7]; + accum.data[8] += data[7] * rhs.data[4]; + accum.data[9] += data[7] * rhs.data[5]; + accum.data[10] += data[7] * rhs.data[6]; + accum.data[11] += data[7] * rhs.data[7]; + accum.data[12] += data[10] * rhs.data[4]; + accum.data[13] += data[10] * rhs.data[5]; + accum.data[14] += data[10] * rhs.data[6]; + accum.data[15] += data[10] * rhs.data[7]; + + // k=2 + accum.data[0] += data[2] * rhs.data[8]; + accum.data[1] += data[2] * rhs.data[9]; + accum.data[2] += data[2] * rhs.data[10]; + accum.data[3] += data[2] * rhs.data[11]; + accum.data[4] += data[5] * rhs.data[8]; + accum.data[5] += data[5] * rhs.data[9]; + accum.data[6] += data[5] * rhs.data[10]; + accum.data[7] += data[5] * rhs.data[11]; + accum.data[8] += data[8] * rhs.data[8]; + accum.data[9] += data[8] * rhs.data[9]; + accum.data[10] += data[8] * rhs.data[10]; + accum.data[11] += data[8] * rhs.data[11]; + accum.data[12] += data[11] * rhs.data[8]; + accum.data[13] += data[11] * rhs.data[9]; + accum.data[14] += data[11] * rhs.data[10]; + accum.data[15] += data[11] * rhs.data[11]; + + return accum; + } + + /// Matrix product of size 4-by-4-by-3 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Returns the sum of elements + CUTLASS_HOST_DEVICE + Element sum(Element accum = Element()) const { + + accum += data[0]; + accum += data[1]; + accum += data[2]; + accum += data[3]; + accum += data[4]; + accum += data[5]; + accum += data[6]; + accum += data[7]; + accum += data[8]; + accum += data[9]; + accum += data[10]; + accum += data[11]; + + return accum; + } + + /// Returns the sum of squared elements + CUTLASS_HOST_DEVICE + Element norm(Element accum = Element()) const { + + accum += data[0] * data[0]; + accum += data[1] * data[1]; + accum += data[2] * data[2]; + accum += data[3] * data[3]; + accum += data[4] * data[4]; + accum += data[5] * data[5]; + accum += data[6] * data[6]; + accum += data[7] * data[7]; + accum += data[8] * data[8]; + accum += data[9] * data[9]; + accum += data[10] * data[10]; + accum += data[11] * data[11]; + + return accum; + } + + /// Returns square root of the norm + CUTLASS_HOST_DEVICE + Element magnitude() const { + return fast_sqrt(norm()); + } + + /// Returns the sum of diagonal elements + CUTLASS_HOST_DEVICE + Element trace(Element accum = Element()) const { + + accum += data[0]; + accum += data[4]; + accum += data[8]; + + return accum; + } + +}; + +/// Template alias for 4-by-3 matrix +template +using Matrix4x3 = Matrix; + + +/// Free funciton to infer element type from template arguments +template +CUTLASS_HOST_DEVICE Matrix4x3 make_Matrix4x3( + Element _0_0, Element _0_1, Element _0_2, + Element _1_0, Element _1_1, Element _1_2, + Element _2_0, Element _2_1, Element _2_2, + Element _3_0, Element _3_1, Element _3_2 +) { + return Matrix4x3( + _0_0, _0_1, _0_2, + _1_0, _1_1, _1_2, + _2_0, _2_1, _2_2, + _3_0, _3_1, _3_2 + ); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 4-by-4 matrix template class definition +template +struct Matrix 
{ + + // + // Type definitions + // + + /// Element data type + using Element = Element_; + + /// Number of rows in matrix + static int const kRows = 4; + + /// Number of columns in matrix + static int const kColumns = 4; + + /// Layout of matrix in underlying array + using Layout = layout::RowMajor; + + /// Number of elements in matrix + static int const kCount = 16; + + // + // Data members + // + + /// Elements of the matrix in row-major layout + Array data; + + // + // Methods + // + + /// Constructs a zero matrix + CUTLASS_HOST_DEVICE + Matrix() { + data.clear(); + } + + /// Copy constructor for a 4-by-4 matrix + CUTLASS_HOST_DEVICE + Matrix(Matrix const &rhs) { + data = rhs.data; + } + + /// Constucts a 4-by-4 matrix from scalar elements + CUTLASS_HOST_DEVICE + Matrix( + Element _0_0, Element _0_1, Element _0_2, Element _0_3, + Element _1_0, Element _1_1, Element _1_2, Element _1_3, + Element _2_0, Element _2_1, Element _2_2, Element _2_3, + Element _3_0, Element _3_1, Element _3_2, Element _3_3 + ) { + + data[0] = _0_0; data[1] = _0_1; data[2] = _0_2; data[3] = _0_3; + data[4] = _1_0; data[5] = _1_1; data[6] = _1_2; data[7] = _1_3; + data[8] = _2_0; data[9] = _2_1; data[10] = _2_2; data[11] = _2_3; + data[12] = _3_0; data[13] = _3_1; data[14] = _3_2; data[15] = _3_3; + } + + /// Constucts a 4-by-4 matrix from row vectors + CUTLASS_HOST_DEVICE + Matrix( + Matrix const &row_0, + Matrix const &row_1, + Matrix const &row_2, + Matrix const &row_3 + ) { + data[0] = row_0.data[0]; + data[1] = row_0.data[1]; + data[2] = row_0.data[2]; + data[3] = row_0.data[3]; + data[4] = row_1.data[0]; + data[5] = row_1.data[1]; + data[6] = row_1.data[2]; + data[7] = row_1.data[3]; + data[8] = row_2.data[0]; + data[9] = row_2.data[1]; + data[10] = row_2.data[2]; + data[11] = row_2.data[3]; + data[12] = row_3.data[0]; + data[13] = row_3.data[1]; + data[14] = row_3.data[2]; + data[15] = row_3.data[3]; + } + + /// Static method to construct a 4-by-4 matrix from column vectors + CUTLASS_HOST_DEVICE + static Matrix from_columns( + Matrix const &column_0, + Matrix const &column_1, + Matrix const &column_2, + Matrix const &column_3 + ) { + Matrix result; + + result.data[0] = column_0.data[0]; + result.data[1] = column_1.data[0]; + result.data[2] = column_2.data[0]; + result.data[3] = column_3.data[0]; + result.data[4] = column_0.data[1]; + result.data[5] = column_1.data[1]; + result.data[6] = column_2.data[1]; + result.data[7] = column_3.data[1]; + result.data[8] = column_0.data[2]; + result.data[9] = column_1.data[2]; + result.data[10] = column_2.data[2]; + result.data[11] = column_3.data[2]; + result.data[12] = column_0.data[3]; + result.data[13] = column_1.data[3]; + result.data[14] = column_2.data[3]; + result.data[15] = column_3.data[3]; + return result; + } + + /// Constructs an identity matrix + CUTLASS_HOST_DEVICE + static Matrix identity() { + Matrix m; + + m.data[0] = Element(1); + m.data[5] = Element(1); + m.data[10] = Element(1); + m.data[15] = Element(1); + + return m; + } + + /// Constructs a matrix from a uniform element + CUTLASS_HOST_DEVICE + static Matrix uniform(Element s) { + Matrix m; + + m.data[0] = s; + m.data[1] = s; + m.data[2] = s; + m.data[3] = s; + m.data[4] = s; + m.data[5] = s; + m.data[6] = s; + m.data[7] = s; + m.data[8] = s; + m.data[9] = s; + m.data[10] = s; + m.data[11] = s; + m.data[12] = s; + m.data[13] = s; + m.data[14] = s; + m.data[15] = s; + + return m; + } + + /// Constructs a matrix from a uniform element 1 + CUTLASS_HOST_DEVICE + static Matrix ones() { + return 
uniform(Element(1)); + } + + /// Constructs a matrix from a uniform element 0 + CUTLASS_HOST_DEVICE + static Matrix zero() { + return Matrix(); + } + + /// Constructs a matrix from elements along its diagonal + CUTLASS_HOST_DEVICE + static Matrix from_diagonal(Matrix const &diag) { + Matrix m; + + m.data[0] = diag.data[0]; + m.data[5] = diag.data[1]; + m.data[10] = diag.data[2]; + m.data[15] = diag.data[3]; + + return m; + } + + /// Constructs a matrix from elements along its diagonal + CUTLASS_HOST_DEVICE + static Matrix from_diagonal(Matrix const &diag) { + Matrix m; + + m.data[0] = diag.data[0]; + m.data[5] = diag.data[1]; + m.data[10] = diag.data[2]; + m.data[15] = diag.data[3]; + + return m; + } + + /// Gets an array of diagonal elements + CUTLASS_HOST_DEVICE + Matrix diagonal() const { + Matrix diag; + + diag.data[0] = data[0]; + diag.data[1] = data[5]; + diag.data[2] = data[10]; + diag.data[3] = data[15]; + + return diag; + } + + /// Returns a transposed matrix + CUTLASS_HOST_DEVICE + Matrix transpose() const { + Matrix mt; + + mt.data[0] = data[0]; + mt.data[4] = data[1]; + mt.data[8] = data[2]; + mt.data[12] = data[3]; + mt.data[1] = data[4]; + mt.data[5] = data[5]; + mt.data[9] = data[6]; + mt.data[13] = data[7]; + mt.data[2] = data[8]; + mt.data[6] = data[9]; + mt.data[10] = data[10]; + mt.data[14] = data[11]; + mt.data[3] = data[12]; + mt.data[7] = data[13]; + mt.data[11] = data[14]; + mt.data[15] = data[15]; + + return mt; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(int i, int j) const { + return data[i * 4 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(int i, int j) { + return data[i * 4 + j]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element at(Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & at(Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element &at(int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element at(int offset) const { + return data[offset]; + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element operator[](Coord<2> const &coord) const { + return at(coord[0], coord[1]); + } + + /// Accesses an element by coordinate + CUTLASS_HOST_DEVICE + Element & operator[](Coord<2> const &coord) { + return at(coord[0], coord[1]); + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element & operator[](int offset) { + return data[offset]; + } + + /// Accesses an element by offset + CUTLASS_HOST_DEVICE + Element operator[](int offset) const { + return data[offset]; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + + return m; + } + + /// Overwrites a submatrix with optional offset + 
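+ /// The 1-by-3 slice occupies row i, columns j..j+2 of this matrix's row-major storage
+ /// (offsets i * 4 + j through i * 4 + j + 2). A minimal usage sketch, assuming
+ /// Element = float and that the Matrix4x4 and Matrix1x3 aliases (and a ones() factory
+ /// on the 1-by-3 specialization) are defined elsewhere in this header:
+ ///
+ ///   Matrix4x4<float> m = Matrix4x4<float>::identity();
+ ///   m.set_slice_1x3(Matrix1x3<float>::ones(), 3, 0);   // bottom row becomes [1 1 1 1]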
CUTLASS_HOST_DEVICE + Matrix & set_slice_1x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_1x4(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + m.data[3] = data[i * 4 + j + 3]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_1x4(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + data[i * 4 + j + 3] = m.data[3]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix row(int i) const { + return slice_1x4(i, 0); + } + + Matrix &set_row(Matrix const &v, int i = 0) { + return set_slice_1x4(v, i, 0); + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 4]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 4] = m.data[1]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 4]; + m.data[3] = data[i * 4 + j + 5]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 4] = m.data[2]; + data[i * 4 + j + 5] = m.data[3]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + m.data[3] = data[i * 4 + j + 4]; + m.data[4] = data[i * 4 + j + 5]; + m.data[5] = data[i * 4 + j + 6]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + data[i * 4 + j + 4] = m.data[3]; + data[i * 4 + j + 5] = m.data[4]; + data[i * 4 + j + 6] = m.data[5]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_2x4(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + m.data[3] = data[i * 4 + j + 3]; + m.data[4] = data[i * 4 + j + 4]; + m.data[5] = data[i * 4 + j + 5]; + m.data[6] = data[i * 4 + j + 6]; + m.data[7] = data[i * 4 + j + 7]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_2x4(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + data[i * 4 + j + 3] = m.data[3]; + data[i * 4 + j + 4] = m.data[4]; + data[i * 4 + j + 5] = m.data[5]; + data[i * 4 + j 
+ 6] = m.data[6]; + data[i * 4 + j + 7] = m.data[7]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x1(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 4]; + m.data[2] = data[i * 4 + j + 8]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 4] = m.data[1]; + data[i * 4 + j + 8] = m.data[2]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 4]; + m.data[3] = data[i * 4 + j + 5]; + m.data[4] = data[i * 4 + j + 8]; + m.data[5] = data[i * 4 + j + 9]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 4] = m.data[2]; + data[i * 4 + j + 5] = m.data[3]; + data[i * 4 + j + 8] = m.data[4]; + data[i * 4 + j + 9] = m.data[5]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + m.data[3] = data[i * 4 + j + 4]; + m.data[4] = data[i * 4 + j + 5]; + m.data[5] = data[i * 4 + j + 6]; + m.data[6] = data[i * 4 + j + 8]; + m.data[7] = data[i * 4 + j + 9]; + m.data[8] = data[i * 4 + j + 10]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + data[i * 4 + j + 4] = m.data[3]; + data[i * 4 + j + 5] = m.data[4]; + data[i * 4 + j + 6] = m.data[5]; + data[i * 4 + j + 8] = m.data[6]; + data[i * 4 + j + 9] = m.data[7]; + data[i * 4 + j + 10] = m.data[8]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_3x4(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + m.data[3] = data[i * 4 + j + 3]; + m.data[4] = data[i * 4 + j + 4]; + m.data[5] = data[i * 4 + j + 5]; + m.data[6] = data[i * 4 + j + 6]; + m.data[7] = data[i * 4 + j + 7]; + m.data[8] = data[i * 4 + j + 8]; + m.data[9] = data[i * 4 + j + 9]; + m.data[10] = data[i * 4 + j + 10]; + m.data[11] = data[i * 4 + j + 11]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_3x4(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + data[i * 4 + j + 3] = m.data[3]; + data[i * 4 + j + 4] = m.data[4]; + data[i * 4 + j + 5] = m.data[5]; + data[i * 4 + j + 6] = m.data[6]; + data[i * 4 + j + 7] = m.data[7]; + data[i * 4 + j + 8] = m.data[8]; + data[i * 4 + j + 9] = m.data[9]; + data[i * 4 + j + 10] = m.data[10]; + data[i * 4 + j + 11] = m.data[11]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_4x1(int i = 0, int j = 0) const { 
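+ // A 4-by-1 slice gathers one element from each row of the row-major storage, hence the
+ // offsets i * 4 + j plus 0, 4, 8, and 12 below. column(j), defined just after this
+ // method, is simply slice_4x1(0, j). A minimal sketch, assuming Element = float and the
+ // Matrix4x1 and Matrix4x4 aliases defined in this header:
+ //
+ //   Matrix4x1<float> c1 = m.column(1);   // elements m(0,1), m(1,1), m(2,1), m(3,1)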
+ Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 4]; + m.data[2] = data[i * 4 + j + 8]; + m.data[3] = data[i * 4 + j + 12]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_4x1(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 4] = m.data[1]; + data[i * 4 + j + 8] = m.data[2]; + data[i * 4 + j + 12] = m.data[3]; + + return *this; + } + + CUTLASS_HOST_DEVICE + Matrix column(int j) const { + return slice_4x1(0, j); + } + + Matrix &set_column(Matrix const &v, int j =0) { + return set_slice_4x1(v, 0, j); + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_4x2(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 4]; + m.data[3] = data[i * 4 + j + 5]; + m.data[4] = data[i * 4 + j + 8]; + m.data[5] = data[i * 4 + j + 9]; + m.data[6] = data[i * 4 + j + 12]; + m.data[7] = data[i * 4 + j + 13]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_4x2(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 4] = m.data[2]; + data[i * 4 + j + 5] = m.data[3]; + data[i * 4 + j + 8] = m.data[4]; + data[i * 4 + j + 9] = m.data[5]; + data[i * 4 + j + 12] = m.data[6]; + data[i * 4 + j + 13] = m.data[7]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_4x3(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + m.data[3] = data[i * 4 + j + 4]; + m.data[4] = data[i * 4 + j + 5]; + m.data[5] = data[i * 4 + j + 6]; + m.data[6] = data[i * 4 + j + 8]; + m.data[7] = data[i * 4 + j + 9]; + m.data[8] = data[i * 4 + j + 10]; + m.data[9] = data[i * 4 + j + 12]; + m.data[10] = data[i * 4 + j + 13]; + m.data[11] = data[i * 4 + j + 14]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_4x3(Matrix const &m, int i = 0, int j = 0) { + + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + data[i * 4 + j + 4] = m.data[3]; + data[i * 4 + j + 5] = m.data[4]; + data[i * 4 + j + 6] = m.data[5]; + data[i * 4 + j + 8] = m.data[6]; + data[i * 4 + j + 9] = m.data[7]; + data[i * 4 + j + 10] = m.data[8]; + data[i * 4 + j + 12] = m.data[9]; + data[i * 4 + j + 13] = m.data[10]; + data[i * 4 + j + 14] = m.data[11]; + + return *this; + } + + /// Gets a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix slice_4x4(int i = 0, int j = 0) const { + Matrix m; + + m.data[0] = data[i * 4 + j + 0]; + m.data[1] = data[i * 4 + j + 1]; + m.data[2] = data[i * 4 + j + 2]; + m.data[3] = data[i * 4 + j + 3]; + m.data[4] = data[i * 4 + j + 4]; + m.data[5] = data[i * 4 + j + 5]; + m.data[6] = data[i * 4 + j + 6]; + m.data[7] = data[i * 4 + j + 7]; + m.data[8] = data[i * 4 + j + 8]; + m.data[9] = data[i * 4 + j + 9]; + m.data[10] = data[i * 4 + j + 10]; + m.data[11] = data[i * 4 + j + 11]; + m.data[12] = data[i * 4 + j + 12]; + m.data[13] = data[i * 4 + j + 13]; + m.data[14] = data[i * 4 + j + 14]; + m.data[15] = data[i * 4 + j + 15]; + + return m; + } + + /// Overwrites a submatrix with optional offset + CUTLASS_HOST_DEVICE + Matrix & set_slice_4x4(Matrix const &m, int i = 0, int j = 0) { 
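+ // Element (r, c) of m lands at offset i * 4 + j + (r * 4 + c) of this matrix's row-major
+ // storage. With the default i = j = 0 this overwrites all sixteen elements; any other
+ // offset would index past the 4-by-4 storage, so in practice this acts as a whole-matrix
+ // copy. A minimal sketch with hypothetical operands (Element = float assumed):
+ //
+ //   dst.set_slice_4x4(src);   // dst now equals src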
+ + data[i * 4 + j + 0] = m.data[0]; + data[i * 4 + j + 1] = m.data[1]; + data[i * 4 + j + 2] = m.data[2]; + data[i * 4 + j + 3] = m.data[3]; + data[i * 4 + j + 4] = m.data[4]; + data[i * 4 + j + 5] = m.data[5]; + data[i * 4 + j + 6] = m.data[6]; + data[i * 4 + j + 7] = m.data[7]; + data[i * 4 + j + 8] = m.data[8]; + data[i * 4 + j + 9] = m.data[9]; + data[i * 4 + j + 10] = m.data[10]; + data[i * 4 + j + 11] = m.data[11]; + data[i * 4 + j + 12] = m.data[12]; + data[i * 4 + j + 13] = m.data[13]; + data[i * 4 + j + 14] = m.data[14]; + data[i * 4 + j + 15] = m.data[15]; + + return *this; + } + + /// Forms a 4-by-4 matrix by horizontally concatenating a 4-by-1 matrix with a 4-by-3 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), rhs.at(0, 0), rhs.at(0, 1), rhs.at(0, 2) + , lhs.at(1, 0), rhs.at(1, 0), rhs.at(1, 1), rhs.at(1, 2) + , lhs.at(2, 0), rhs.at(2, 0), rhs.at(2, 1), rhs.at(2, 2) + , lhs.at(3, 0), rhs.at(3, 0), rhs.at(3, 1), rhs.at(3, 2)); + } + + /// Forms a 4-by-4 matrix by horizontally concatenating a 4-by-2 matrix with a 4-by-2 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), lhs.at(0, 1), rhs.at(0, 0), rhs.at(0, 1) + , lhs.at(1, 0), lhs.at(1, 1), rhs.at(1, 0), rhs.at(1, 1) + , lhs.at(2, 0), lhs.at(2, 1), rhs.at(2, 0), rhs.at(2, 1) + , lhs.at(3, 0), lhs.at(3, 1), rhs.at(3, 0), rhs.at(3, 1)); + } + + /// Forms a 4-by-4 matrix by horizontally concatenating a 4-by-3 matrix with a 4-by-1 matrix + CUTLASS_HOST_DEVICE + static Matrix hcat(Matrix const & lhs, Matrix const & rhs) { + return Matrix( + lhs.at(0, 0), lhs.at(0, 1), lhs.at(0, 2), rhs.at(0, 0) + , lhs.at(1, 0), lhs.at(1, 1), lhs.at(1, 2), rhs.at(1, 0) + , lhs.at(2, 0), lhs.at(2, 1), lhs.at(2, 2), rhs.at(2, 0) + , lhs.at(3, 0), lhs.at(3, 1), lhs.at(3, 2), rhs.at(3, 0)); + } + + /// Forms a 4-by-4 matrix by vertically concatenating a 1-by-4 matrix with a 3-by-4 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Matrix const & lower) { + return Matrix( + upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) + , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3) + , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2), lower.at(1, 3) + , lower.at(2, 0), lower.at(2, 1), lower.at(2, 2), lower.at(2, 3)); + } + + /// Forms a 4-by-4 matrix by vertically concatenating a 2-by-4 matrix with a 2-by-4 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Matrix const & lower) { + return Matrix( + upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) + , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2), upper.at(1, 3) + , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3) + , lower.at(1, 0), lower.at(1, 1), lower.at(1, 2), lower.at(1, 3)); + } + + /// Forms a 4-by-4 matrix by vertically concatenating a 3-by-4 matrix with a 1-by-4 matrix + CUTLASS_HOST_DEVICE + static Matrix vcat(Matrix const & upper, Matrix const & lower) { + return Matrix( + upper.at(0, 0), upper.at(0, 1), upper.at(0, 2), upper.at(0, 3) + , upper.at(1, 0), upper.at(1, 1), upper.at(1, 2), upper.at(1, 3) + , upper.at(2, 0), upper.at(2, 1), upper.at(2, 2), upper.at(2, 3) + , lower.at(0, 0), lower.at(0, 1), lower.at(0, 2), lower.at(0, 3)); + } + + /// Forms a 4-by-4 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Element A, Matrix const & B, + Matrix const & C, Matrix const & D) { + return Matrix( + A, B.at(0, 0), B.at(0, 
1), B.at(0, 2) + , C.at(0, 0), D.at(0, 0), D.at(0, 1), D.at(0, 2) + , C.at(1, 0), D.at(1, 0), D.at(1, 1), D.at(1, 2) + , C.at(2, 0), D.at(2, 0), D.at(2, 1), D.at(2, 2) + ); + } + + /// Forms a 4-by-4 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Matrix const & C, Matrix const & D) { + return Matrix( + A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) + , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) + , C.at(1, 0), C.at(1, 1), D.at(1, 0), D.at(1, 1) + , C.at(2, 0), C.at(2, 1), D.at(2, 0), D.at(2, 1) + ); + } + + /// Forms a 4-by-4 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Element B, + Matrix const & C, Matrix const & D) { + return Matrix( + A.at(0, 0), A.at(0, 1), A.at(0, 2), B + , C.at(0, 0), C.at(0, 1), C.at(0, 2), D.at(0, 0) + , C.at(1, 0), C.at(1, 1), C.at(1, 2), D.at(1, 0) + , C.at(2, 0), C.at(2, 1), C.at(2, 2), D.at(2, 0) + ); + } + + /// Forms a 4-by-4 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Matrix const & C, Matrix const & D) { + return Matrix( + A.at(0, 0), B.at(0, 0), B.at(0, 1), B.at(0, 2) + , A.at(1, 0), B.at(1, 0), B.at(1, 1), B.at(1, 2) + , C.at(0, 0), D.at(0, 0), D.at(0, 1), D.at(0, 2) + , C.at(1, 0), D.at(1, 0), D.at(1, 1), D.at(1, 2) + ); + } + + /// Forms a 4-by-4 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Matrix const & C, Matrix const & D) { + return Matrix( + A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) + , A.at(1, 0), A.at(1, 1), B.at(1, 0), B.at(1, 1) + , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) + , C.at(1, 0), C.at(1, 1), D.at(1, 0), D.at(1, 1) + ); + } + + /// Forms a 4-by-4 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Matrix const & C, Matrix const & D) { + return Matrix( + A.at(0, 0), A.at(0, 1), A.at(0, 2), B.at(0, 0) + , A.at(1, 0), A.at(1, 1), A.at(1, 2), B.at(1, 0) + , C.at(0, 0), C.at(0, 1), C.at(0, 2), D.at(0, 0) + , C.at(1, 0), C.at(1, 1), C.at(1, 2), D.at(1, 0) + ); + } + + /// Forms a 4-by-4 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Element C, Matrix const & D) { + return Matrix( + A.at(0, 0), B.at(0, 0), B.at(0, 1), B.at(0, 2) + , A.at(1, 0), B.at(1, 0), B.at(1, 1), B.at(1, 2) + , A.at(2, 0), B.at(2, 0), B.at(2, 1), B.at(2, 2) + , C, D.at(0, 0), D.at(0, 1), D.at(0, 2) + ); + } + + /// Forms a 4-by-4 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Matrix const & C, Matrix const & D) { + return Matrix( + A.at(0, 0), A.at(0, 1), B.at(0, 0), B.at(0, 1) + , A.at(1, 0), A.at(1, 1), B.at(1, 0), B.at(1, 1) + , A.at(2, 0), A.at(2, 1), B.at(2, 0), B.at(2, 1) + , C.at(0, 0), C.at(0, 1), D.at(0, 0), D.at(0, 1) + ); + } + + /// Forms a 4-by-4 matrix by concatenating four components + CUTLASS_HOST_DEVICE + static Matrix block( + Matrix const & A, Matrix const & B, + Matrix const & C, Element D) { + return Matrix( + A.at(0, 0), A.at(0, 1), A.at(0, 2), B.at(0, 0) + , A.at(1, 0), A.at(1, 1), A.at(1, 2), B.at(1, 0) + , A.at(2, 0), A.at(2, 1), A.at(2, 2), B.at(2, 0) + , C.at(0, 0), C.at(0, 1), C.at(0, 2), D + ); + } + + /// Elementwise add operator (4-by-4) + CUTLASS_HOST_DEVICE + Matrix add(Matrix const &rhs) const { + + 
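+ // Addition is elementwise over the sixteen entries of the row-major array; operator+
+ // below forwards to this method. A minimal sketch with hypothetical operands, assuming
+ // Element = float and the Matrix4x4 alias defined later in this file:
+ //
+ //   Matrix4x4<float> C = A + B;   // C(i, j) = A(i, j) + B(i, j)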
Matrix result; + + result.data[0] = data[0] + rhs.data[0]; + result.data[1] = data[1] + rhs.data[1]; + result.data[2] = data[2] + rhs.data[2]; + result.data[3] = data[3] + rhs.data[3]; + + result.data[4] = data[4] + rhs.data[4]; + result.data[5] = data[5] + rhs.data[5]; + result.data[6] = data[6] + rhs.data[6]; + result.data[7] = data[7] + rhs.data[7]; + + result.data[8] = data[8] + rhs.data[8]; + result.data[9] = data[9] + rhs.data[9]; + result.data[10] = data[10] + rhs.data[10]; + result.data[11] = data[11] + rhs.data[11]; + + result.data[12] = data[12] + rhs.data[12]; + result.data[13] = data[13] + rhs.data[13]; + result.data[14] = data[14] + rhs.data[14]; + result.data[15] = data[15] + rhs.data[15]; + + return result; + } + + /// Elementwise add operator (4-by-4) + CUTLASS_HOST_DEVICE + Matrix operator +(Matrix const &rhs) const { + return add(rhs); + } + + /// Elementwise add operator (4-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator +=(Matrix const &rhs) { + + data[0] += rhs.data[0]; + data[1] += rhs.data[1]; + data[2] += rhs.data[2]; + data[3] += rhs.data[3]; + + data[4] += rhs.data[4]; + data[5] += rhs.data[5]; + data[6] += rhs.data[6]; + data[7] += rhs.data[7]; + + data[8] += rhs.data[8]; + data[9] += rhs.data[9]; + data[10] += rhs.data[10]; + data[11] += rhs.data[11]; + + data[12] += rhs.data[12]; + data[13] += rhs.data[13]; + data[14] += rhs.data[14]; + data[15] += rhs.data[15]; + + return *this; + } + + /// Elementwise subtract operator (4-by-4) + CUTLASS_HOST_DEVICE + Matrix subtract(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] - rhs.data[0]; + result.data[1] = data[1] - rhs.data[1]; + result.data[2] = data[2] - rhs.data[2]; + result.data[3] = data[3] - rhs.data[3]; + + result.data[4] = data[4] - rhs.data[4]; + result.data[5] = data[5] - rhs.data[5]; + result.data[6] = data[6] - rhs.data[6]; + result.data[7] = data[7] - rhs.data[7]; + + result.data[8] = data[8] - rhs.data[8]; + result.data[9] = data[9] - rhs.data[9]; + result.data[10] = data[10] - rhs.data[10]; + result.data[11] = data[11] - rhs.data[11]; + + result.data[12] = data[12] - rhs.data[12]; + result.data[13] = data[13] - rhs.data[13]; + result.data[14] = data[14] - rhs.data[14]; + result.data[15] = data[15] - rhs.data[15]; + + return result; + } + + /// Elementwise subtract operator (4-by-4) + CUTLASS_HOST_DEVICE + Matrix operator -(Matrix const &rhs) const { + return subtract(rhs); + } + + /// Elementwise subtract operator (4-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator -=(Matrix const &rhs) { + + data[0] -= rhs.data[0]; + data[1] -= rhs.data[1]; + data[2] -= rhs.data[2]; + data[3] -= rhs.data[3]; + + data[4] -= rhs.data[4]; + data[5] -= rhs.data[5]; + data[6] -= rhs.data[6]; + data[7] -= rhs.data[7]; + + data[8] -= rhs.data[8]; + data[9] -= rhs.data[9]; + data[10] -= rhs.data[10]; + data[11] -= rhs.data[11]; + + data[12] -= rhs.data[12]; + data[13] -= rhs.data[13]; + data[14] -= rhs.data[14]; + data[15] -= rhs.data[15]; + + return *this; + } + + /// Elementwise multiply operator (4-by-4) + CUTLASS_HOST_DEVICE + Matrix multiply(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] * rhs.data[0]; + result.data[1] = data[1] * rhs.data[1]; + result.data[2] = data[2] * rhs.data[2]; + result.data[3] = data[3] * rhs.data[3]; + + result.data[4] = data[4] * rhs.data[4]; + result.data[5] = data[5] * rhs.data[5]; + result.data[6] = data[6] * rhs.data[6]; + result.data[7] = data[7] * rhs.data[7]; + + result.data[8] = data[8] * rhs.data[8]; + result.data[9] = data[9] * 
rhs.data[9]; + result.data[10] = data[10] * rhs.data[10]; + result.data[11] = data[11] * rhs.data[11]; + + result.data[12] = data[12] * rhs.data[12]; + result.data[13] = data[13] * rhs.data[13]; + result.data[14] = data[14] * rhs.data[14]; + result.data[15] = data[15] * rhs.data[15]; + + return result; + } + + /// Scalar multiply operator (4-by-4) + CUTLASS_HOST_DEVICE + Matrix multiply(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] * s; + result.data[1] = data[1] * s; + result.data[2] = data[2] * s; + result.data[3] = data[3] * s; + + result.data[4] = data[4] * s; + result.data[5] = data[5] * s; + result.data[6] = data[6] * s; + result.data[7] = data[7] * s; + + result.data[8] = data[8] * s; + result.data[9] = data[9] * s; + result.data[10] = data[10] * s; + result.data[11] = data[11] * s; + + result.data[12] = data[12] * s; + result.data[13] = data[13] * s; + result.data[14] = data[14] * s; + result.data[15] = data[15] * s; + + return result; + } + + /// Scalar multiply operator (4-by-4) + CUTLASS_HOST_DEVICE + Matrix operator *(Element const &s) const { + return multiply(s); + } + + /// Scalar multiply operator (4-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator *=(Element const &s) { + + data[0] *= s; + data[1] *= s; + data[2] *= s; + data[3] *= s; + + data[4] *= s; + data[5] *= s; + data[6] *= s; + data[7] *= s; + + data[8] *= s; + data[9] *= s; + data[10] *= s; + data[11] *= s; + + data[12] *= s; + data[13] *= s; + data[14] *= s; + data[15] *= s; + + return *this; + } + + /// Elementwise divide operator (4-by-4) + CUTLASS_HOST_DEVICE + Matrix divide(Matrix const &rhs) const { + + Matrix result; + + result.data[0] = data[0] / rhs.data[0]; + result.data[1] = data[1] / rhs.data[1]; + result.data[2] = data[2] / rhs.data[2]; + result.data[3] = data[3] / rhs.data[3]; + + result.data[4] = data[4] / rhs.data[4]; + result.data[5] = data[5] / rhs.data[5]; + result.data[6] = data[6] / rhs.data[6]; + result.data[7] = data[7] / rhs.data[7]; + + result.data[8] = data[8] / rhs.data[8]; + result.data[9] = data[9] / rhs.data[9]; + result.data[10] = data[10] / rhs.data[10]; + result.data[11] = data[11] / rhs.data[11]; + + result.data[12] = data[12] / rhs.data[12]; + result.data[13] = data[13] / rhs.data[13]; + result.data[14] = data[14] / rhs.data[14]; + result.data[15] = data[15] / rhs.data[15]; + + return result; + } + + /// Scalar divide operator (4-by-4) + CUTLASS_HOST_DEVICE + Matrix divide(Element const &s) const { + + Matrix result; + + result.data[0] = data[0] / s; + result.data[1] = data[1] / s; + result.data[2] = data[2] / s; + result.data[3] = data[3] / s; + + result.data[4] = data[4] / s; + result.data[5] = data[5] / s; + result.data[6] = data[6] / s; + result.data[7] = data[7] / s; + + result.data[8] = data[8] / s; + result.data[9] = data[9] / s; + result.data[10] = data[10] / s; + result.data[11] = data[11] / s; + + result.data[12] = data[12] / s; + result.data[13] = data[13] / s; + result.data[14] = data[14] / s; + result.data[15] = data[15] / s; + + return result; + } + + /// Scalar divide operator (4-by-4) + CUTLASS_HOST_DEVICE + Matrix operator /(Element const &s) const { + return divide(s); + } + + /// Scalar divide operator (4-by-4) + CUTLASS_HOST_DEVICE + Matrix & operator /=(Element const &s) { + + data[0] /= s; + data[1] /= s; + data[2] /= s; + data[3] /= s; + + data[4] /= s; + data[5] /= s; + data[6] /= s; + data[7] /= s; + + data[8] /= s; + data[9] /= s; + data[10] /= s; + data[11] /= s; + + data[12] /= s; + data[13] /= s; + data[14] /= s; + data[15] 
/= s;
+
+ return *this;
+ }
+
+ /// Elementwise divide operator (4-by-4)
+ CUTLASS_HOST_DEVICE
+ Matrix operator /(Matrix const &rhs) const {
+ return divide(rhs);
+ }
+
+ /// Elementwise divide operator (4-by-4)
+ CUTLASS_HOST_DEVICE
+ Matrix & operator /=(Matrix const &rhs) {
+
+ data[0] /= rhs.data[0];
+ data[1] /= rhs.data[1];
+ data[2] /= rhs.data[2];
+ data[3] /= rhs.data[3];
+
+ data[4] /= rhs.data[4];
+ data[5] /= rhs.data[5];
+ data[6] /= rhs.data[6];
+ data[7] /= rhs.data[7];
+
+ data[8] /= rhs.data[8];
+ data[9] /= rhs.data[9];
+ data[10] /= rhs.data[10];
+ data[11] /= rhs.data[11];
+
+ data[12] /= rhs.data[12];
+ data[13] /= rhs.data[13];
+ data[14] /= rhs.data[14];
+ data[15] /= rhs.data[15];
+
+ return *this;
+ }
+
+ /// Negates each element of the matrix
+ CUTLASS_HOST_DEVICE
+ Matrix operator-() const {
+ Matrix m;
+
+ m.data[0] = -data[0];
+ m.data[1] = -data[1];
+ m.data[2] = -data[2];
+ m.data[3] = -data[3];
+ m.data[4] = -data[4];
+ m.data[5] = -data[5];
+ m.data[6] = -data[6];
+ m.data[7] = -data[7];
+ m.data[8] = -data[8];
+ m.data[9] = -data[9];
+ m.data[10] = -data[10];
+ m.data[11] = -data[11];
+ m.data[12] = -data[12];
+ m.data[13] = -data[13];
+ m.data[14] = -data[14];
+ m.data[15] = -data[15];
+
+ return m;
+ }
+
+ /// Matrix product of size 4-by-1-by-4
+ CUTLASS_HOST_DEVICE
+ Matrix product(
+ Matrix const &rhs,
+ Matrix accum = Matrix()
+ ) const {
+
+ // k=0
+ accum.data[0] += data[0] * rhs.data[0];
+ accum.data[1] += data[4] * rhs.data[0];
+ accum.data[2] += data[8] * rhs.data[0];
+ accum.data[3] += data[12] * rhs.data[0];
+
+ // k=1
+ accum.data[0] += data[1] * rhs.data[1];
+ accum.data[1] += data[5] * rhs.data[1];
+ accum.data[2] += data[9] * rhs.data[1];
+ accum.data[3] += data[13] * rhs.data[1];
+
+ // k=2
+ accum.data[0] += data[2] * rhs.data[2];
+ accum.data[1] += data[6] * rhs.data[2];
+ accum.data[2] += data[10] * rhs.data[2];
+ accum.data[3] += data[14] * rhs.data[2];
+
+ // k=3
+ accum.data[0] += data[3] * rhs.data[3];
+ accum.data[1] += data[7] * rhs.data[3];
+ accum.data[2] += data[11] * rhs.data[3];
+ accum.data[3] += data[15] * rhs.data[3];
+
+ return accum;
+ }
+
+ /// Matrix product of size 4-by-1-by-4
+ CUTLASS_HOST_DEVICE
+ Matrix operator*(Matrix const &rhs) const {
+ return product(rhs);
+ }
+
+ /// Matrix product of size 4-by-2-by-4
+ CUTLASS_HOST_DEVICE
+ Matrix product(
+ Matrix const &rhs,
+ Matrix accum = Matrix()
+ ) const {
+
+ // k=0
+ accum.data[0] += data[0] * rhs.data[0];
+ accum.data[1] += data[0] * rhs.data[1];
+ accum.data[2] += data[4] * rhs.data[0];
+ accum.data[3] += data[4] * rhs.data[1];
+ accum.data[4] += data[8] * rhs.data[0];
+ accum.data[5] += data[8] * rhs.data[1];
+ accum.data[6] += data[12] * rhs.data[0];
+ accum.data[7] += data[12] * rhs.data[1];
+
+ // k=1
+ accum.data[0] += data[1] * rhs.data[2];
+ accum.data[1] += data[1] * rhs.data[3];
+ accum.data[2] += data[5] * rhs.data[2];
+ accum.data[3] += data[5] * rhs.data[3];
+ accum.data[4] += data[9] * rhs.data[2];
+ accum.data[5] += data[9] * rhs.data[3];
+ accum.data[6] += data[13] * rhs.data[2];
+ accum.data[7] += data[13] * rhs.data[3];
+
+ // k=2
+ accum.data[0] += data[2] * rhs.data[4];
+ accum.data[1] += data[2] * rhs.data[5];
+ accum.data[2] += data[6] * rhs.data[4];
+ accum.data[3] += data[6] * rhs.data[5];
+ accum.data[4] += data[10] * rhs.data[4];
+ accum.data[5] += data[10] * rhs.data[5];
+ accum.data[6] += data[14] * rhs.data[4];
+ accum.data[7] += data[14] * rhs.data[5];
+
+ // k=3
+ accum.data[0] += data[3] * 
rhs.data[6]; + accum.data[1] += data[3] * rhs.data[7]; + accum.data[2] += data[7] * rhs.data[6]; + accum.data[3] += data[7] * rhs.data[7]; + accum.data[4] += data[11] * rhs.data[6]; + accum.data[5] += data[11] * rhs.data[7]; + accum.data[6] += data[15] * rhs.data[6]; + accum.data[7] += data[15] * rhs.data[7]; + + return accum; + } + + /// Matrix product of size 4-by-2-by-4 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 4-by-3-by-4 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[4] * rhs.data[0]; + accum.data[4] += data[4] * rhs.data[1]; + accum.data[5] += data[4] * rhs.data[2]; + accum.data[6] += data[8] * rhs.data[0]; + accum.data[7] += data[8] * rhs.data[1]; + accum.data[8] += data[8] * rhs.data[2]; + accum.data[9] += data[12] * rhs.data[0]; + accum.data[10] += data[12] * rhs.data[1]; + accum.data[11] += data[12] * rhs.data[2]; + + // k=1 + accum.data[0] += data[1] * rhs.data[3]; + accum.data[1] += data[1] * rhs.data[4]; + accum.data[2] += data[1] * rhs.data[5]; + accum.data[3] += data[5] * rhs.data[3]; + accum.data[4] += data[5] * rhs.data[4]; + accum.data[5] += data[5] * rhs.data[5]; + accum.data[6] += data[9] * rhs.data[3]; + accum.data[7] += data[9] * rhs.data[4]; + accum.data[8] += data[9] * rhs.data[5]; + accum.data[9] += data[13] * rhs.data[3]; + accum.data[10] += data[13] * rhs.data[4]; + accum.data[11] += data[13] * rhs.data[5]; + + // k=2 + accum.data[0] += data[2] * rhs.data[6]; + accum.data[1] += data[2] * rhs.data[7]; + accum.data[2] += data[2] * rhs.data[8]; + accum.data[3] += data[6] * rhs.data[6]; + accum.data[4] += data[6] * rhs.data[7]; + accum.data[5] += data[6] * rhs.data[8]; + accum.data[6] += data[10] * rhs.data[6]; + accum.data[7] += data[10] * rhs.data[7]; + accum.data[8] += data[10] * rhs.data[8]; + accum.data[9] += data[14] * rhs.data[6]; + accum.data[10] += data[14] * rhs.data[7]; + accum.data[11] += data[14] * rhs.data[8]; + + // k=3 + accum.data[0] += data[3] * rhs.data[9]; + accum.data[1] += data[3] * rhs.data[10]; + accum.data[2] += data[3] * rhs.data[11]; + accum.data[3] += data[7] * rhs.data[9]; + accum.data[4] += data[7] * rhs.data[10]; + accum.data[5] += data[7] * rhs.data[11]; + accum.data[6] += data[11] * rhs.data[9]; + accum.data[7] += data[11] * rhs.data[10]; + accum.data[8] += data[11] * rhs.data[11]; + accum.data[9] += data[15] * rhs.data[9]; + accum.data[10] += data[15] * rhs.data[10]; + accum.data[11] += data[15] * rhs.data[11]; + + return accum; + } + + /// Matrix product of size 4-by-3-by-4 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 4-by-4-by-4 + CUTLASS_HOST_DEVICE + Matrix product( + Matrix const &rhs, + Matrix accum = Matrix() + ) const { + + // k=0 + accum.data[0] += data[0] * rhs.data[0]; + accum.data[1] += data[0] * rhs.data[1]; + accum.data[2] += data[0] * rhs.data[2]; + accum.data[3] += data[0] * rhs.data[3]; + accum.data[4] += data[4] * rhs.data[0]; + accum.data[5] += data[4] * rhs.data[1]; + accum.data[6] += data[4] * rhs.data[2]; + accum.data[7] += data[4] * rhs.data[3]; + accum.data[8] += data[8] * rhs.data[0]; + accum.data[9] += data[8] * rhs.data[1]; + accum.data[10] += data[8] * rhs.data[2]; + accum.data[11] += data[8] * rhs.data[3]; + accum.data[12] += data[12] 
* rhs.data[0]; + accum.data[13] += data[12] * rhs.data[1]; + accum.data[14] += data[12] * rhs.data[2]; + accum.data[15] += data[12] * rhs.data[3]; + + // k=1 + accum.data[0] += data[1] * rhs.data[4]; + accum.data[1] += data[1] * rhs.data[5]; + accum.data[2] += data[1] * rhs.data[6]; + accum.data[3] += data[1] * rhs.data[7]; + accum.data[4] += data[5] * rhs.data[4]; + accum.data[5] += data[5] * rhs.data[5]; + accum.data[6] += data[5] * rhs.data[6]; + accum.data[7] += data[5] * rhs.data[7]; + accum.data[8] += data[9] * rhs.data[4]; + accum.data[9] += data[9] * rhs.data[5]; + accum.data[10] += data[9] * rhs.data[6]; + accum.data[11] += data[9] * rhs.data[7]; + accum.data[12] += data[13] * rhs.data[4]; + accum.data[13] += data[13] * rhs.data[5]; + accum.data[14] += data[13] * rhs.data[6]; + accum.data[15] += data[13] * rhs.data[7]; + + // k=2 + accum.data[0] += data[2] * rhs.data[8]; + accum.data[1] += data[2] * rhs.data[9]; + accum.data[2] += data[2] * rhs.data[10]; + accum.data[3] += data[2] * rhs.data[11]; + accum.data[4] += data[6] * rhs.data[8]; + accum.data[5] += data[6] * rhs.data[9]; + accum.data[6] += data[6] * rhs.data[10]; + accum.data[7] += data[6] * rhs.data[11]; + accum.data[8] += data[10] * rhs.data[8]; + accum.data[9] += data[10] * rhs.data[9]; + accum.data[10] += data[10] * rhs.data[10]; + accum.data[11] += data[10] * rhs.data[11]; + accum.data[12] += data[14] * rhs.data[8]; + accum.data[13] += data[14] * rhs.data[9]; + accum.data[14] += data[14] * rhs.data[10]; + accum.data[15] += data[14] * rhs.data[11]; + + // k=3 + accum.data[0] += data[3] * rhs.data[12]; + accum.data[1] += data[3] * rhs.data[13]; + accum.data[2] += data[3] * rhs.data[14]; + accum.data[3] += data[3] * rhs.data[15]; + accum.data[4] += data[7] * rhs.data[12]; + accum.data[5] += data[7] * rhs.data[13]; + accum.data[6] += data[7] * rhs.data[14]; + accum.data[7] += data[7] * rhs.data[15]; + accum.data[8] += data[11] * rhs.data[12]; + accum.data[9] += data[11] * rhs.data[13]; + accum.data[10] += data[11] * rhs.data[14]; + accum.data[11] += data[11] * rhs.data[15]; + accum.data[12] += data[15] * rhs.data[12]; + accum.data[13] += data[15] * rhs.data[13]; + accum.data[14] += data[15] * rhs.data[14]; + accum.data[15] += data[15] * rhs.data[15]; + + return accum; + } + + /// Matrix product of size 4-by-4-by-4 + CUTLASS_HOST_DEVICE + Matrix operator*(Matrix const &rhs) const { + return product(rhs); + } + + /// Matrix product of size 4-by-4-by-4 + CUTLASS_HOST_DEVICE + Matrix & operator*=(Matrix const &rhs) { + *this = product(rhs); + return *this; + } + + /// Returns the sum of elements + CUTLASS_HOST_DEVICE + Element sum(Element accum = Element()) const { + + accum += data[0]; + accum += data[1]; + accum += data[2]; + accum += data[3]; + accum += data[4]; + accum += data[5]; + accum += data[6]; + accum += data[7]; + accum += data[8]; + accum += data[9]; + accum += data[10]; + accum += data[11]; + accum += data[12]; + accum += data[13]; + accum += data[14]; + accum += data[15]; + + return accum; + } + + /// Returns the sum of squared elements + CUTLASS_HOST_DEVICE + Element norm(Element accum = Element()) const { + + accum += data[0] * data[0]; + accum += data[1] * data[1]; + accum += data[2] * data[2]; + accum += data[3] * data[3]; + accum += data[4] * data[4]; + accum += data[5] * data[5]; + accum += data[6] * data[6]; + accum += data[7] * data[7]; + accum += data[8] * data[8]; + accum += data[9] * data[9]; + accum += data[10] * data[10]; + accum += data[11] * data[11]; + accum += data[12] * data[12]; + accum += 
data[13] * data[13];
+ accum += data[14] * data[14];
+ accum += data[15] * data[15];
+
+ return accum;
+ }
+
+ /// Returns square root of the norm
+ CUTLASS_HOST_DEVICE
+ Element magnitude() const {
+ return fast_sqrt(norm());
+ }
+
+ /// Returns the sum of diagonal elements
+ CUTLASS_HOST_DEVICE
+ Element trace(Element accum = Element()) const {
+
+ accum += data[0];
+ accum += data[5];
+ accum += data[10];
+ accum += data[15];
+
+ return accum;
+ }
+
+ /// Returns 4-by-4 rotation matrix around the X axis
+ CUTLASS_HOST_DEVICE
+ static Matrix rotation_X(Element theta) {
+ Matrix m = identity();
+
+ Element c = fast_cos(theta);
+ Element s = fast_sin(theta);
+
+ m.at(1, 1) = c;
+ m.at(1, 2) = -s;
+ m.at(2, 1) = s;
+ m.at(2, 2) = c;
+
+ return m;
+ }
+
+ /// Returns 4-by-4 rotation matrix around the Y axis
+ CUTLASS_HOST_DEVICE
+ static Matrix rotation_Y(Element theta) {
+ Matrix m = identity();
+
+ Element c = fast_cos(theta);
+ Element s = fast_sin(theta);
+
+ m.at(0, 0) = c;
+ m.at(2, 0) = -s;
+ m.at(0, 2) = s;
+ m.at(2, 2) = c;
+
+ return m;
+ }
+
+ /// Returns 4-by-4 rotation matrix around the Z axis
+ CUTLASS_HOST_DEVICE
+ static Matrix rotation_Z(Element theta) {
+ Matrix m = Matrix::identity();
+
+ Element c = fast_cos(theta);
+ Element s = fast_sin(theta);
+
+ m.at(0, 0) = c;
+ m.at(0, 1) = -s;
+ m.at(1, 0) = s;
+ m.at(1, 1) = c;
+
+ return m;
+ }
+
+ /// Returns a 4-by-4 rotation matrix around a unit-length axis
+ CUTLASS_HOST_DEVICE
+ static Matrix rotation(Element theta, Matrix const &u) {
+ Element x = u.data[0];
+ Element y = u.data[1];
+ Element z = u.data[2];
+
+ Element c = fast_cos(theta);
+ Element s = fast_sin(theta);
+
+ Element one_minus_cos = Element(1) - fast_cos(theta);
+
+ Matrix m;
+
+ m.set_slice_3x3({
+ c + x * x * one_minus_cos, x * y * one_minus_cos - z * s, x * z * one_minus_cos + y * s,
+ y * x * one_minus_cos + z * s, c + y * y * one_minus_cos, y * z * one_minus_cos - x * s,
+ z * x * one_minus_cos - y * s, z * y * one_minus_cos + x * s, c + z * z * one_minus_cos
+ });
+
+ return m;
+ }
+
+ /// Returns a 4-by-4 reflection about the plane specified by the
+ /// unit-length normal vector n_unit
+ CUTLASS_HOST_DEVICE
+ static Matrix reflection(Matrix const &n_unit) {
+
+ Element a = n_unit.data[0];
+ Element b = n_unit.data[1];
+ Element c = n_unit.data[2];
+
+ Matrix m = Matrix::identity();
+
+ m.set_slice_3x3({
+ Element(1) - Element(2) * a * a, Element(-2) * a * b, Element(-2) * a * c,
+ Element(-2) * a * b, Element(1) - Element(2) * b * b, Element(-2) * b * c,
+ Element(-2) * a * c, Element(-2) * b * c, Element(1) - Element(2) * c * c
+ });
+
+ return m;
+ }
+
+ /// Returns a perspective projection matrix typical of OpenGL applications
+ CUTLASS_HOST_DEVICE
+ static Matrix perspective(Element near, Element far, Element fovH, Element fovV) {
+ Element aspect = fovH / fovV;
+ Element f = Element(cos(fovV)) / Element(fovH);
+ Element Q = near - far;
+
+ return Matrix(
+ f / aspect, 0, 0, 0,
+ 0, f, 0, 0,
+ 0, 0, (near + far) / Q, Element(2) * far * near / Q,
+ 0, 0, -1, 0
+ );
+ }
+
+ CUTLASS_HOST_DEVICE
+ static Matrix translation(Matrix const &v) {
+ return Matrix(
+ 1, 0, 0, v.data[0],
+ 0, 1, 0, v.data[1],
+ 0, 0, 1, v.data[2],
+ 0, 0, 0, 1
+ );
+ }
+
+ /// Computes the determinant of a 4-by-4 matrix
+ CUTLASS_HOST_DEVICE
+ Element determinant(Element accum = Element()) const {
+
+ accum += at(0, 0) * Matrix({ at(1, 1), at(1, 2), at(1, 3), at(2, 1), at(2, 2), at(2, 3), at(3, 1), at(3, 2), at(3, 3) }).determinant();
+ accum -= at(0, 1) * Matrix({ at(1, 0), 
at(1, 2), at(1, 3), at(2, 0), at(2, 2), at(2, 3), at(3, 0), at(3, 2), at(3, 3) }).determinant(); + accum += at(0, 2) * Matrix({ at(1, 0), at(1, 1), at(1, 3), at(2, 0), at(2, 1), at(2, 3), at(3, 0), at(3, 1), at(3, 3) }).determinant(); + accum -= at(0, 3) * Matrix({ at(1, 0), at(1, 1), at(1, 2), at(2, 0), at(2, 1), at(2, 2), at(3, 0), at(3, 1), at(3, 2) }).determinant(); + + return accum; + } + + /// Computes the inverse of a 4-by-4 matrix (ignores the optional argument) + CUTLASS_HOST_DEVICE + Matrix inverse(Element ignore = 1) const { + Matrix B = slice_2x2(0, 2); + Matrix A = slice_2x2(0, 0); + Matrix C = slice_2x2(2, 0); + Matrix D = slice_2x2(2, 2); + + Matrix D_inv = D.inverse(); + + Matrix E = (A - B * D_inv * C).inverse(); + + return Matrix::block( + E, -E * B * D_inv, + -D_inv * C * E, D_inv + D_inv * C * E * B * D_inv + ); + } + +}; + +/// Template alias for 4-by-4 matrix +template +using Matrix4x4 = Matrix; + + +/// Free function to infer element type from template arguments +template +CUTLASS_HOST_DEVICE Matrix4x4 make_Matrix4x4( + Element _0_0, Element _0_1, Element _0_2, Element _0_3, + Element _1_0, Element _1_1, Element _1_2, Element _1_3, + Element _2_0, Element _2_1, Element _2_2, Element _2_3, + Element _3_0, Element _3_1, Element _3_2, Element _3_3 +) { + return Matrix4x4( + _0_0, _0_1, _0_2, _0_3, + _1_0, _1_1, _1_2, _1_3, + _2_0, _2_1, _2_2, _2_3, + _3_0, _3_1, _3_2, _3_3 + ); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Elementwise scalar multiplication +template +CUTLASS_HOST_DEVICE +Matrix operator*(Element s, Matrix const &rhs) { + return rhs.multiply(s); +} + +/// Prints matrix to ostream +template +std::ostream & operator<<(std::ostream &out, Matrix const &rhs) { + + for (int i = 0; i < Rows; ++i) { + for (int j = 0; j < Columns; ++j) { + out << (j ?
", " : "") << rhs.at(i, j); + } + out << "\n"; + } + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 78181ce79a..766478e085 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -515,17 +515,30 @@ struct NumericConverterClamp { using source_type = float; static_assert((platform::is_same::value || + platform::is_same::value || + platform::is_same::value || platform::is_same::value || - platform::is_same::value), + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), "Clamp is only needed for integer types"); CUTLASS_HOST_DEVICE static result_type convert(source_type const & s) { NumericConverter convert_op; + double kClamp_max, kClamp_min; - double kClamp_max = double((1U << (sizeof_bits::value - 1)) - 1); - double kClamp_min = -kClamp_max - 1; + if (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value) { + kClamp_max = double((1LLU << (sizeof_bits::value - 1)) - 1); + kClamp_min = -kClamp_max - 1; + } else { + kClamp_max = double((1LLU << (sizeof_bits::value)) - 1); + kClamp_min = 0; + } double source = s; @@ -946,6 +959,130 @@ struct NumericArrayConverter { return convert(s); } }; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + NumericConverter convert_element_; + + result_type result; + + result[0] = convert_element_(source[0]); + + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + + uint32_t tmp; + + asm volatile( + "cvt.pack.sat.u8.s32.b32 %0, %2, %1, 0;\n" + : "=r"(tmp) : "r"(source[0]), "r"(source[1])); + + uint16_t out = (tmp & 0xffff); + return reinterpret_cast(out); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + + unsigned out; + + asm volatile( + "{ .reg .u32 r4;" + "cvt.pack.sat.u8.s32.b32 r4, %4, %3, 0;" + "cvt.pack.sat.u8.s32.b32 %0, %2, %1, r4;" + "}" + : "=r"(out) : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3])); + + return reinterpret_cast(out); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + static_assert(!(N % 4), "N must be multiple 
of 4."); + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + + NumericArrayConverter convert_vector_; + + result_type result; + + Array *result_ptr = reinterpret_cast *>(&result); + Array const *source_ptr = reinterpret_cast const *>(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 4; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + #endif ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -1025,12 +1162,84 @@ struct NumericArrayConverter { } }; +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + + unsigned out; + + asm volatile( + "{ .reg .u32 r4;" + "cvt.pack.sat.u4.s32.b32 r4, %8, %7, 0;" + "cvt.pack.sat.u4.s32.b32 r4, %6, %5, r4;" + "cvt.pack.sat.u4.s32.b32 r4, %4, %3, r4;" + "cvt.pack.sat.u4.s32.b32 %0, %2, %1, r4;" + "}" + : "=r"(out) + : "r"(source[0]), "r"(source[1]), "r"(source[2]), "r"(source[3]), + "r"(source[4]), "r"(source[5]), "r"(source[6]), "r"(source[7])); + + return reinterpret_cast(out); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + static_assert(!(N % 8), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + + NumericArrayConverter convert_vector_; + + result_type result; + + Array *result_ptr = reinterpret_cast *>(&result); + Array const *source_ptr = reinterpret_cast const *>(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 8; ++i) { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) { + return convert(s); + } +}; + #endif // Conditional guards to enable partial specialization for packed integers ///////////////////////////////////////////////////////////////////////////////////////////////// /// FastNumericArrayConverter only works when the source is within center range. -/// Conversion operator for Array +/// Conversion operator for Array. See the comments before +/// FastLinearCombinationClamp. template struct FastNumericArrayConverter { diff --git a/include/cutlass/quaternion.h b/include/cutlass/quaternion.h new file mode 100644 index 0000000000..aef35025d3 --- /dev/null +++ b/include/cutlass/quaternion.h @@ -0,0 +1,616 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines a densely packed quaternion object intended for storing data in registers and + executing quaternion operations within a CUDA or host thread. +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/matrix.h" +#include "cutlass/fast_math.h" +#include "cutlass/layout/vector.h" + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Quaternion: xi + yj + zk + w +template < + typename Element_ = float ///< element type +> +class Quaternion : public Array { +public: + + /// Logical rank of tensor index space + static int const kRank = 1; + + /// Number of elements + static int const kExtent = 4; + + /// Base class is a four-element array + using Base = Array; + + /// Element type + using Element = typename Base::Element; + + /// Reference type to an element + using Reference = typename Base::reference; + + /// Index type + using Index = int; + + /// Quaternion storage - imaginary part + static int const kX = 0; + + /// Quaternion storage - imaginary part + static int const kY = 1; + + /// Quaternion storage - imaginary part + static int const kZ = 2; + + /// Quaternion storage - real part + static int const kW = 3; + +public: + + // + // Methods + // + + /// Constructs a quaternion + CUTLASS_HOST_DEVICE + Quaternion( + Element w_ = Element(1) + ) { + Base::at(kX) = Element(0); + Base::at(kY) = Element(0); + Base::at(kZ) = Element(0); + Base::at(kW) = w_; + } + + /// Constructs a quaternion + CUTLASS_HOST_DEVICE + Quaternion( + Element x_, + Element y_, + Element z_, + Element w_ + ) { + Base::at(kX) = x_; + Base::at(kY) = y_; + Base::at(kZ) = z_; + Base::at(kW) = w_; + } + + /// Constructs a quaternion from a vector representing the imaginary part and a real number + CUTLASS_HOST_DEVICE + Quaternion( + Matrix3x1 const &imag_, + Element w_ = Element() + ) { + Base::at(kX) = imag_[0]; + Base::at(kY) = imag_[1]; + Base::at(kZ) = imag_[2]; + Base::at(kW) = w_; + } + + /// Returns a reference to the element at a given Coord + CUTLASS_HOST_DEVICE + Reference at(Index idx) const { + return Base::at(idx); + } + + /// Returns a reference to the element at a 
given Coord + CUTLASS_HOST_DEVICE + Reference at(Index idx) { + return Base::at(idx); + } + + /// Accesses the x element of the imaginary part of the quaternion + CUTLASS_HOST_DEVICE + Element x() const { + return Base::at(kX); + } + + /// Accesses the x element of the imaginary part of the quaternion + CUTLASS_HOST_DEVICE + Reference x() { + return Base::at(kX); + } + + /// Accesses the y element of the imaginary part of the quaternion + CUTLASS_HOST_DEVICE + Element y() const { + return Base::at(kY); + } + + /// Accesses the y element of the imaginary part of the quaternion + CUTLASS_HOST_DEVICE + Reference y() { + return Base::at(kY); + } + + /// Accesses the z element of the imaginary part of the quaternion + CUTLASS_HOST_DEVICE + Element z() const { + return Base::at(kZ); + } + + /// Accesses the z element of the imaginary part of the quaternion + CUTLASS_HOST_DEVICE + Reference z() { + return Base::at(kZ); + } + + /// Accesses the real part of the quaternion + CUTLASS_HOST_DEVICE + Element w() const { + return Base::at(kW); + } + + /// Accesses the real part of the quaternion + CUTLASS_HOST_DEVICE + Reference w() { + return Base::at(kW); + } + + /// Returns the pure imaginary part of the quaternion as a 3-vector + CUTLASS_HOST_DEVICE + Matrix3x1 pure() const { + return Matrix3x1(x(), y(), z()); + } + + /// Returns a quaternion representation of a spatial rotation given a unit-length axis and + /// a rotation in radians. + CUTLASS_HOST_DEVICE + static Quaternion rotation( + Matrix3x1 const &axis_unit, ///< axis of rotation (assumed to be unit length) + Element theta) { ///< angular rotation in radians + + Element s = fast_sin(theta / Element(2)); + + return Quaternion( + s * axis_unit[0], + s * axis_unit[1], + s * axis_unit[2], + fast_cos(theta / Element(2)) + ); + } + + /// Returns a quaternion representation of a spatial rotation represented as a + /// unit-length rotation axis (r_x, r_y, r_z) and an angular rotation in radians + CUTLASS_HOST_DEVICE + static Quaternion rotation( + Element r_x, + Element r_y, + Element r_z, + Element theta) { ///< angular rotation in radians + + return rotation({r_x, r_y, r_z}, theta); + } + + /// Geometric rotation of a 3-element vector + CUTLASS_HOST_DEVICE + Matrix3x1 rotate(Matrix3x1 const &rhs) const { + return (*this * Quaternion(rhs, 0) * reciprocal(*this)).pure(); + } + + /// Inverse rotation operation + CUTLASS_HOST_DEVICE + Matrix3x1 rotate_inv(Matrix3x1 const &rhs) const { + return (reciprocal(*this) * Quaternion(rhs, 0) * *this).pure(); + } + + /// Rotates a 3-vector assuming this is a unit quaternion (a spinor) + CUTLASS_HOST_DEVICE + Matrix3x1 spinor(Matrix3x1 const &rhs) const { + return (*this * Quaternion(rhs, 0) * conj(*this)).pure(); + } + + /// Inverse rotation of 3-vector assuming this is a unit quaternion (a spinor) + CUTLASS_HOST_DEVICE + Matrix3x1 spinor_inv(Matrix3x1 const &rhs) const { + return (conj(*this) * Quaternion(rhs, 0) * *this).pure(); + } + + /// In-place addition + template + CUTLASS_HOST_DEVICE + Quaternion &operator+=(Quaternion const &rhs) { + *this = (*this + rhs); + return *this; + } + + /// In-place subtraction + template + CUTLASS_HOST_DEVICE + Quaternion &operator-=(Quaternion const &rhs) { + *this = (*this - rhs); + return *this; + } + + /// In-place multiplication + template + CUTLASS_HOST_DEVICE + Quaternion &operator*=(Quaternion const &rhs) { + *this = (*this * rhs); + return *this; + } + + /// Scalar multiplication + template + CUTLASS_HOST_DEVICE + Quaternion &operator*=(Element s) { + *this = 
(*this * s); + return *this; + } + + /// In-place Division + template + CUTLASS_HOST_DEVICE + Quaternion &operator/=(Quaternion const &rhs) { + *this = (*this / rhs); + return *this; + } + + /// In-place Division + template + CUTLASS_HOST_DEVICE + Quaternion &operator/=(Element s) { + *this = (*this / s); + return *this; + } + + /// Computes a 3x3 rotation matrix (row-major representation) + CUTLASS_HOST_DEVICE + Matrix3x3 as_rotation_matrix_3x3() const { + Matrix3x3 m( + w() * w() + x() * x() - y() * y() - z() * z(), + 2 * x() * y() - 2 * w() * z(), + 2 * x() * z() + 2 * w() * y(), + + 2 * x() * y() + 2 * w() * z(), + w() * w() - x() * x() + y() * y() - z() * z(), + 2 * y() * z() - 2 * w() * x(), + + 2 * x() * z() - 2 * w() * y(), + 2 * y() * z() + 2 * w() * x(), + w() * w() - x() * x() - y() * y() + z() * z() + ); + return m; + } + + /// Computes a 4x4 rotation matrix (row-major representation) + CUTLASS_HOST_DEVICE + Matrix4x4 as_rotation_matrix_4x4() const { + Matrix4x4 m = Matrix4x4::identity(); + m.set_slice_3x3(as_rotation_matrix_3x3()); + return m; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructs a quaternion that is non-zero only in its real element. +template +CUTLASS_HOST_DEVICE +Quaternion make_Quaternion( + Element w) { ///< real part + + return Quaternion(w); +} + +/// Constructs a quaternion from a vector and real +template +CUTLASS_HOST_DEVICE +Quaternion make_Quaternion( + Matrix3x1 const &imag, ///< imaginary party as a vector + Element w) { ///< real part + + return Quaternion(imag, w); +} + +/// Constructs a quaternion from a unit-length rotation axis and a rotation +/// angle in radians +template +CUTLASS_HOST_DEVICE +Quaternion make_QuaternionRotation( + Matrix3x1 const &axis_unit, ///< rotation axis (unit-length) + Element w) { ///< rotation angle in radians + + return Quaternion::rotation(axis_unit, w); +} + +/// Constructs a quaternion q = xi + yj + zk + w +template +CUTLASS_HOST_DEVICE +Quaternion make_Quaternion(Element x, Element y, Element z, Element w) { + return Quaternion(x, y, z, w); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns the magnitude of the complex number +template +CUTLASS_HOST_DEVICE +Element abs(Quaternion const &q) { + return fast_sqrt(norm(q)); +} + +/// Quaternion conjugate +template +CUTLASS_HOST_DEVICE +Quaternion conj(Quaternion const &q) { + return make_Quaternion( + -q.x(), + -q.y(), + -q.z(), + q.w() + ); +} + +/// Computes the squared magnitude of the quaternion +template +CUTLASS_HOST_DEVICE +Element norm(Quaternion const &q) { + return q.x() * q.x() + q.y() * q.y() + q.z() * q.z() + q.w() * q.w(); +} + +/// Quaternion reciprocal +template +CUTLASS_HOST_DEVICE +Quaternion reciprocal(Quaternion const &q) { + + Element nsq = norm(q); + + return make_Quaternion( + -q.x() / nsq, + -q.y() / nsq, + -q.z() / nsq, + q.w() / nsq + ); +} + +/// Returns a unit-length quaternion +template +CUTLASS_HOST_DEVICE +Quaternion unit(Quaternion const &q) { + + Element rcp_mag = Element(1) / abs(q); + + return make_Quaternion( + q.x() * rcp_mag, + q.y() * rcp_mag, + q.z() * rcp_mag, + q.w() * rcp_mag + ); +} + +/// Quaternion exponential +template +CUTLASS_HOST_DEVICE +Quaternion exp(Quaternion const &q) { + + Element exp_ = fast_exp(q.w()); + Element imag_norm = fast_sqrt(q.x() * q.x() + q.y() * q.y() + q.z() * q.z()); + Element sin_norm = fast_sin(imag_norm); + + return make_Quaternion( + exp_ * q.x() * 
sin_norm / imag_norm, + exp_ * q.y() * sin_norm / imag_norm, + exp_ * q.z() * sin_norm / imag_norm, + exp_ * fast_cos(imag_norm) + ); +} + +/// Quaternion natural logarithm +template +CUTLASS_HOST_DEVICE +Quaternion log(Quaternion const &q) { + + Element v = fast_sqrt(q.x() * q.x() + q.y() * q.y() + q.z() * q.z()); + Element s = fast_acos(q.w() / abs(q)) / v; + + return make_Quaternion( + q.x() * s, + q.y() * s, + q.z() * s, + fast_log(q.w()) + ); +} + +/// Gets the rotation angle from a unit-length quaternion +template +CUTLASS_HOST_DEVICE +Element get_rotation_angle(Quaternion const &q_unit) { + return fast_acos(q_unit.w()) * Element(2); +} + +/// Gets the rotation axis from a unit-length quaternion +template +CUTLASS_HOST_DEVICE +Matrix3x1 get_rotation_axis(Quaternion const &q_unit) { + return q_unit.pure().unit(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Equality operator +template +CUTLASS_HOST_DEVICE +bool operator==(Quaternion const &lhs, Quaternion const &rhs) { + return lhs.x() == rhs.x() && + lhs.y() == rhs.y() && + lhs.z() == rhs.z() && + lhs.w() == rhs.w(); +} + +/// Inequality operator +template +CUTLASS_HOST_DEVICE +bool operator!=(Quaternion const &lhs, Quaternion const &rhs) { + return !(lhs == rhs); +} + +/// Quaternion scalar multiplication +template +CUTLASS_HOST_DEVICE +Quaternion operator*(Quaternion q, Element s) { + return make_Quaternion( + q.x() * s, + q.y() * s, + q.z() * s, + q.w() * s + ); +} + +/// Quaternion scalar multiplication +template +CUTLASS_HOST_DEVICE +Quaternion operator*(Element s, Quaternion const &q) { + return make_Quaternion( + s * q.x(), + s * q.y(), + s * q.z(), + s * q.w() + ); +} + +/// Quaternion scalar division +template +CUTLASS_HOST_DEVICE +Quaternion operator/(Quaternion const &q, Element s) { + return make_Quaternion( + q.x() / s, + q.y() / s, + q.z() / s, + q.w() / s + ); +} + +/// Quaternion unary negation +template +CUTLASS_HOST_DEVICE +Quaternion operator-(Quaternion const &q) { + return make_Quaternion( + -q.x(), + -q.y(), + -q.z(), + -q.w() + ); +} + +/// Quaternion addition +template +CUTLASS_HOST_DEVICE +Quaternion operator+(Quaternion const &lhs, Quaternion const &rhs) { + return make_Quaternion( + lhs.x() + rhs.x(), + lhs.y() + rhs.y(), + lhs.z() + rhs.z(), + lhs.w() + rhs.w() + ); +} + +/// Quaternion subtraction +template +CUTLASS_HOST_DEVICE +Quaternion operator-(Quaternion const &lhs, Quaternion const &rhs) { + return make_Quaternion( + lhs.x() - rhs.x(), + lhs.y() - rhs.y(), + lhs.z() - rhs.z(), + lhs.w() - rhs.w() + ); +} + +/// Quaternion product +template +CUTLASS_HOST_DEVICE +Quaternion operator*(Quaternion const &lhs, Quaternion const &rhs) { + return make_Quaternion( + lhs.w() * rhs.x() + rhs.w() * lhs.x() + lhs.y() * rhs.z() - lhs.z() * rhs.y(), + lhs.w() * rhs.y() + rhs.w() * lhs.y() + lhs.z() * rhs.x() - lhs.x() * rhs.z(), + lhs.w() * rhs.z() + rhs.w() * lhs.z() + lhs.x() * rhs.y() - lhs.y() * rhs.x(), + lhs.w() * rhs.w() - lhs.x() * rhs.x() - lhs.y() * rhs.y() - lhs.z() * rhs.z() + ); +} + +/// Quaternion division +template +CUTLASS_HOST_DEVICE +Quaternion operator/(Quaternion const &lhs, Quaternion const &rhs) { + return lhs * reciprocal(rhs); +} + +/// Quaternion scalar division +template +CUTLASS_HOST_DEVICE +Quaternion operator/(Element s, Quaternion const &q) { + return s * reciprocal(q); +} + +/// Rotates a 3-vector assuming this is a unit quaternion (a spinor). This avoids computing +/// a reciprocal. 
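// For illustration only (not part of this header): assuming the Quaternion and Matrix3x1 templates
// declared above, a unit quaternion built from an axis and an angle rotates a 3-vector either through
// the member spinor() or through the free function declared below; the numeric values here are
// hypothetical.
//
//   cutlass::Quaternion<float> q = cutlass::Quaternion<float>::rotation(0, 0, 1, 3.14159265f / 2);
//   cutlass::Matrix3x1<float> v(1.0f, 0.0f, 0.0f);
//   cutlass::Matrix3x1<float> w  = q.spinor(v);             // approximately (0, 1, 0)
//   cutlass::Matrix3x1<float> w2 = spinor_rotation(q, v);   // same result via the free function
//
// Because q is unit length, both forms use conj() rather than reciprocal(), which is the cheaper
// path noted in the comment above.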
+template +CUTLASS_HOST_DEVICE +Matrix3x1 spinor_rotation( + Quaternion const &spinor, /// unit-length quaternion + Matrix3x1 const &rhs) { /// arbitrary 3-vector + + return (spinor * Quaternion(rhs, 0) * conj(spinor)).pure(); +} + +/// Inverse rotation of 3-vector assuming this is a unit quaternion (a spinor). This avoids computing +/// a reciprocal. +template +CUTLASS_HOST_DEVICE +Matrix3x1 spinor_rotation_inv( + Quaternion const &spinor, /// unit-length quaternion + Matrix3x1 const &rhs) { /// arbitrary 3-vector + + return (conj(spinor) * Quaternion(rhs, 0) * spinor).pure(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Output operators +// + +template +std::ostream &operator<<(std::ostream &out, Quaternion const &q) { + return out << q.w() << "+i" << q.x() << "+j" << q.y() << "+k" << q.z(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/real.h b/include/cutlass/real.h index 45ab1864eb..99af846b19 100644 --- a/include/cutlass/real.h +++ b/include/cutlass/real.h @@ -22,6 +22,11 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ +/** + \file + \brief This class provides helpers to support real<> and complex<> types in generic code. +*/ + #pragma once namespace cutlass { diff --git a/include/cutlass/reduction/batched_reduction.h b/include/cutlass/reduction/batched_reduction.h deleted file mode 100644 index 16132a0210..0000000000 --- a/include/cutlass/reduction/batched_reduction.h +++ /dev/null @@ -1,179 +0,0 @@ -/*************************************************************************************************** -* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. -* -* Redistribution and use in source and binary forms, with or without modification, are permitted -* provided that the following conditions are met: -* * Redistributions of source code must retain the above copyright notice, this list of -* conditions and the following disclaimer. -* * Redistributions in binary form must reproduce the above copyright notice, this list of -* conditions and the following disclaimer in the documentation and/or other materials -* provided with the distribution. -* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used -* to endorse or promote products derived from this software without specific prior written -* permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR -* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND -* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE -* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; -* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-* -**************************************************************************************************/ -/*! \file -\brief Implements a software-pipelined efficient batched reduction. -D = alpha * Reduction(A) + beta * C -*/ -#pragma once - -#if !defined(__CUDACC_RTC__) -#include -#endif - -#include "cutlass/coord.h" -#include "cutlass/util/platform.h" -#include "cutlass/fragment.h" - -namespace cutlass { -namespace reduction { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__global__ __launch_bounds__(batched_reduction_::Traits::kThreads, 1) void batched_reduction_kernel(typename batched_reduction_::Params params) { - // Construct the batched_reduction object - batched_reduction_ batched_reduction(params); - batched_reduction.run(); -} - -template -struct BatchedReduction { - /// This class - typedef BatchedReduction This_; - /// The traits - typedef BatchedReductionTraits_ Traits; - /// Params - typedef typename Traits::Params Params; - /// functor - typedef typename Traits::Functor Functor; - - /// ctor - CUTLASS_DEVICE BatchedReduction(Params const ¶ms_) - : params(params_), functor(params_.functorParams) {} - - /// main operation method - /// D = alpha * Reduction(A) + beta * C - CUTLASS_DEVICE void run() { -#if (__CUDA_ARCH__ >= 600) - // Swizzle the IDs of the block - typename Traits::BlockSwizzle block_swizzle; - Coord<3> threadblock_offset = - block_swizzle.get_threadblock_offset(make_Coord_from_shape()); - - int subTileSize = gridDim.x * Traits::SubTile::kW; - int tileSize = params.problem_size[1] * params.problem_size[2]; - int subTileOffset = threadblock_offset[2] + threadIdx.x * Traits::ThreadShape::kW; - - int subTileBase = 0; - - typename Traits::ScalarA inRegs[Traits::maxInReg]; - typename Traits::ScalarAccum AccumRegs[Traits::maxOutReg]; -#pragma unroll - for (int subTile = 0; subTile < tileSize; subTile += subTileSize) { - int tileOffset = subTileBase + subTileOffset; - // Init AccumRegs -#pragma unroll - for (int i = 0; i < Traits::ThreadShape::kW; i++) - AccumRegs[i] = static_cast(0.0f); - // Fetch c0 - typename Traits::ScalarAccum c0[Traits::ThreadShape::kW]; -#pragma unroll - for (int i = 0; i< Traits::ThreadShape::kW; i++) - c0[i] = static_cast(params.d_c[tileOffset + i]); - - // Fetch partial sums from A -#pragma unroll - for (int s = 0; s < Traits::ReductionSize; s++) { - int inRegOffset = s * Traits::ThreadShape::kW; - int dOffset = (s * tileSize) + tileOffset; -#pragma unroll - for (int i = 0; i< Traits::ThreadShape::kW; i++) { - inRegs[inRegOffset + i] = params.d_a[dOffset + i]; - } - } - - // Accumulate -#pragma unroll - for (int s = 0; s < Traits::ReductionSize; s++) { - int inRegOffset = s * Traits::ThreadShape::kW; -#pragma unroll - for (int i = 0; i < Traits::ThreadShape::kW; i++) { - //AccumRegs[i] = cuFma(params.alpha, inRegs[inRegOffset + i], AccumRegs[i]); - //AccumRegs[i] = params.alpha * inRegs[inRegOffset + i] + AccumRegs[i]; - AccumRegs[i] = static_cast(inRegs[inRegOffset + i]) + AccumRegs[i]; - } - } - // calling functor - functor_caller(AccumRegs, c0, AccumRegs); - - // Store AccumRegs to D -#pragma unroll - for (int i = 0; i < Traits::ThreadShape::kW; i++) { - params.d_d[tileOffset + i] = static_cast(AccumRegs[i]); - } - - // Advance sub-tile pointer - subTileBase += subTileSize; - } // end for loop -#endif //#if (__CUDA_ARCH__ >= 600) - } - - template - CUTLASS_DEVICE void functor_caller(typename Traits::ScalarAccum const *accum, typename Traits::ScalarAccum const *old, 
typename Traits::ScalarAccum *output) { - if (ThreadShapeMultiple2 == true) { -#pragma unroll - for (int i = 0; i < Traits::ThreadShape::kW / 2; i++) { - functor.template evaluate(&accum[2 * i], &old[2 * i], &output[2 * i]); - } - } - else { -#pragma unroll - for (int i = 0; i < Traits::ThreadShape::kW; i++) { - functor.template evaluate(&accum[i], &old[i], &output[i]); - } - } - } - - // - // Static function members - // -#if !defined(__CUDACC_RTC__) - /// Launch the kernel. - static __host__ cudaError_t launch(Params const& params, - cudaStream_t stream = cudaStreamDefault) { - // Setup the grid. - typename Traits::BlockSwizzle block_swizzle; - dim3 grid = block_swizzle.get_grid_layout(params.problem_size, - make_Coord_from_shape()); - - dim3 block; - block.x = Traits::kThreads; - batched_reduction_kernel<<>>(params); - return cudaGetLastError(); - } -#endif - - // - // Data members - // - - /// The params. - Params const& params; - // The functor. - Functor functor; -}; - -} // namespace reduction -} // namespace cutlass diff --git a/include/cutlass/reduction/batched_reduction_traits.h b/include/cutlass/reduction/batched_reduction_traits.h deleted file mode 100644 index 46157dc703..0000000000 --- a/include/cutlass/reduction/batched_reduction_traits.h +++ /dev/null @@ -1,192 +0,0 @@ -/*************************************************************************************************** -* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. -* -* Redistribution and use in source and binary forms, with or without modification, are permitted -* provided that the following conditions are met: -* * Redistributions of source code must retain the above copyright notice, this list of -* conditions and the following disclaimer. -* * Redistributions in binary form must reproduce the above copyright notice, this list of -* conditions and the following disclaimer in the documentation and/or other materials -* provided with the distribution. -* * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used -* to endorse or promote products derived from this software without specific prior written -* permission. -* -* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR -* IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND -* FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE -* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; -* OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -* STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -* -**************************************************************************************************/ -/*! \file -\brief Defines structural properties of complete batched reduction. 
-D = alpha * Reduction(A) + beta * C -*/ -#pragma once -#include "cutlass/cutlass.h" -#include "cutlass/shape.h" -#include "cutlass/reduction/threadblock_swizzle.h" -#include "cutlass/reduction/batched_reduction.h" -#include "cutlass/gemm/linear_scaling.h" - -namespace cutlass { -namespace reduction { - -/* -OutputTile defines the work load per thread block -Subtile defines the work load per thread block per iteration -OutputTile / Subtile = number of iterations within a kernel -ThreadShape defines the work load per thread -Subtile / ThreadShape = number of threads per thread block -*/ -template < - /// The scalar type for A - typename ScalarA_, - /// The scalar type for C - typename ScalarC_, - /// The scalar type for D - typename ScalarD_, - /// the scalar type for alpha, - typename ScalarAlphaBeta_, - /// The scalar type for accumulator - typename ScalarAccum_, - /// Reduction work load per batch - int ReductionSize_ = 1, - /// The output tile, work load per thread block, - typename OutputTile_ = Shape<1, 1, 128>, - /// The subtile - typename SubTile_ = Shape<1, 1, 64>, - /// Work load per thread, per subtile - typename ThreadShape_ = Shape<1, 1, 2>, - /// The index - typename Index_ = int, - /// The block swizzle to reorganize the grid. - typename BlockSwizzle_ = DefaultBlockSwizzle, - /// The input register vector size in kernel - int maxInReg_ = 160, - /// The output register vector size in kernel - int maxOutReg_ = 64, - /// The functor that will be executed at the end - typename Functor_ = typename cutlass::gemm::LinearScaling > -> -struct BatchedReductionTraits { - /// - typedef BatchedReductionTraits This_; - /// The struct that consumes this Traits - typedef typename cutlass::reduction::BatchedReduction KernelClass; - /// - typedef OutputTile_ OutputTile; - /// - typedef SubTile_ SubTile; - /// - typedef ThreadShape_ ThreadShape; - /// The input pointer type - typedef ScalarA_ ScalarA; - /// - typedef ScalarC_ ScalarC; - /// The output pointer type - typedef ScalarD_ ScalarD; - /// The alpha beta type - typedef ScalarAlphaBeta_ ScalarAlphaBeta; - /// The type for accumulation - typedef ScalarAccum_ ScalarAccum; - /// The index - typedef Index_ Index; - /// The thread block swizzle - typedef BlockSwizzle_ BlockSwizzle; - /// - static const int ReductionSize = ReductionSize_; - /// check if threadShape is multiple of 2. - static const bool ThreadShapeMultiple2 = (ThreadShape::kW % 2 == 0); - /// - typedef Functor_ Functor; - /// Parameteres object constructable on the host - /// The number of threads per thread block. 
can be deduced - static int const kThreads = SubTile::kW / ThreadShape::kW; - // - static int const maxInReg = maxInReg_; - // - static int const maxOutReg = maxOutReg_; - // - static_assert(SubTile::kW % ThreadShape::kW == 0, "cannot evenly distribute work load among threads"); - // - static_assert(kThreads % 32 == 0, "threads per threadblock is not multiple of 32"); - // - static_assert(OutputTile::kW % SubTile::kW == 0, "cannot evenly distribute work load among iterations"); - // - static_assert(ReductionSize * ThreadShape::kW <= maxInReg, "ReductionSize * ThreadShape::kW should not be bigger than maxInReg"); - // - static_assert(ThreadShape::kW <= maxOutReg, "ThreadShape::kW should not be bigger than maxOutReg"); - - struct Params { - /// The dimension of output tensor - Coord<3> problem_size; - /// The alpha - ScalarAlphaBeta alpha; - /// The beta - ScalarAlphaBeta beta; - /// stride between two element that will be sumed - long long int reduction_stride; - // - ScalarA const *d_a; - // - Index lda; - // - ScalarC const *d_c; - // - Index ldc; - // - ScalarD *d_d; - // - Index ldd; - /// The functor params. - typename Functor::Params functorParams; - /// Initialize the parameters for 2D output tensor - CUTLASS_HOST_DEVICE int initialize(Index m_, - Index n_, - ScalarAlphaBeta alpha_, - ScalarAlphaBeta beta_, - long long int reduction_stride_, - ScalarA const *d_a_, - Index lda_, - ScalarC const *d_c_, - Index ldc_, - ScalarD *d_d_, - Index ldd_){ - problem_size = make_Coord(1, n_, m_); - alpha = alpha_; - beta = beta_; - reduction_stride = reduction_stride_; - d_a = d_a_; - lda = lda_; - d_c = d_c_; - d_d = d_d_; - ldc = ldc_; - ldd = ldd_; - - functorParams.initialize(alpha_, beta_); - - return 0; - } - }; - -}; -} // namespace reduction -} // namespace cutlass diff --git a/include/cutlass/relatively_equal.h b/include/cutlass/relatively_equal.h index 5714fbd2fd..3d6a43b952 100644 --- a/include/cutlass/relatively_equal.h +++ b/include/cutlass/relatively_equal.h @@ -77,6 +77,18 @@ bool relatively_equal(uint1b_t a, uint1b_t b, uint1b_t, uint1b_t) { return (a == b); } +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(int2b_t a, int2b_t b, int2b_t, int2b_t) { + return (a == b); +} + +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(uint2b_t a, uint2b_t b, uint2b_t, uint2b_t) { + return (a == b); +} + template <> CUTLASS_HOST_DEVICE bool relatively_equal(int4b_t a, int4b_t b, int4b_t, int4b_t) { diff --git a/include/cutlass/tensor_coord.h b/include/cutlass/tensor_coord.h index d7a6d0df6a..b60bc11262 100644 --- a/include/cutlass/tensor_coord.h +++ b/include/cutlass/tensor_coord.h @@ -165,4 +165,146 @@ struct Tensor4DCoord : public Coord<4> { //////////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a canonical 5D coordinate used by tensor operations. 
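// For illustration only (not part of this patch): assuming the Tensor5DCoord defined below and a
// hypothetical NDHWC-style 5-D activation tensor, a coordinate can be constructed and queried as
//
//   cutlass::Tensor5DCoord coord(2, 4, 8, 16, 32);       // (n, d, h, w, c)
//   int depth = coord.d();                               // 4
//   coord += cutlass::Tensor5DCoord(0, 1, 0, 0, 0);      // advance one unit along the depth dimension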
+struct Tensor5DCoord : public Coord<5> { + + /// Base class + using Base = Coord<5>; + + /// Index type + using Index = typename Base::Index; + + /// LongIndex type + using LongIndex = typename Base::LongIndex; + + /// Batch dimension + static int const kN = 0; + + /// Depth dimension + static int const kD = 1; + + /// Height dimension + static int const kH = 2; + + /// Width dimension + static int const kW = 3; + + /// Channels dimension + static int const kC = 4; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Tensor5DCoord() { } + + /// Constructs from Coord<5> + CUTLASS_HOST_DEVICE + Tensor5DCoord(Coord<5> const &coord): Base(coord) { } + + /// Helper to construct from N, D, H, W, and C. + CUTLASS_HOST_DEVICE + Tensor5DCoord(Index n, Index d, Index h, Index w, Index c): Base(make_Coord(n, d, h, w, c)) { } + + /// Returns the batch of the coordinate + CUTLASS_HOST_DEVICE + Index const & n() const { return this->at(kN); } + + /// Returns the batch of the coordinate + CUTLASS_HOST_DEVICE + Index & n() { return this->at(kN); } + + /// Returns the batch of the coordinate + CUTLASS_HOST_DEVICE + Index const & d() const { return this->at(kD); } + + /// Returns the batch of the coordinate + CUTLASS_HOST_DEVICE + Index & d() { return this->at(kD); } + + /// Returns the row of the coordinate + CUTLASS_HOST_DEVICE + Index const & h() const { return this->at(kH); } + + /// Returns the row of the coordinate + CUTLASS_HOST_DEVICE + Index & h() { return this->at(kH); } + + /// Returns the column of the coordinate + CUTLASS_HOST_DEVICE + Index const & w() const { return this->at(kW); } + + /// Returns the column of the coordinate + CUTLASS_HOST_DEVICE + Index & w() { return this->at(kW); } + + /// Returns the channel of the coordinate + CUTLASS_HOST_DEVICE + Index const & c() const { return this->at(kC); } + + /// Returns the channel of the coordinate + CUTLASS_HOST_DEVICE + Index & c() { return this->at(kC); } + + // + // Coord operators + // + + /// Element-wise addition + CUTLASS_HOST_DEVICE + Tensor5DCoord operator+(Base const& b) const { + return Tensor5DCoord(Base::operator+(b)); + } + + /// Element-wise subtraction + CUTLASS_HOST_DEVICE + Tensor5DCoord operator-(Base const& b) const { + return Tensor5DCoord(Base::operator-(b)); + } + + /// Element-wise multiplication + CUTLASS_HOST_DEVICE + Tensor5DCoord operator*(Base const& b) const { + return Tensor5DCoord(Base::operator*(b)); + } + + /// Element-wise division + CUTLASS_HOST_DEVICE + Tensor5DCoord operator/(Base const& b) const { + return Tensor5DCoord(Base::operator/(b)); + } + + /// In-place addition + CUTLASS_HOST_DEVICE + Tensor5DCoord& operator+=(Base const& b) { + Base::operator+=(b); + return *this; + } + + /// In-place subtraction + CUTLASS_HOST_DEVICE + Tensor5DCoord& operator-=(Base const& b) { + Base::operator-=(b); + return *this; + } + + /// In-place multiplication + CUTLASS_HOST_DEVICE + Tensor5DCoord& operator*=(Base const& b) { + Base::operator*=(b); + return *this; + } + + /// In-place division + CUTLASS_HOST_DEVICE + Tensor5DCoord& operator/=(Base const& b) { + Base::operator/=(b); + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass diff --git a/include/cutlass/tensor_view.h b/include/cutlass/tensor_view.h index a9cf569de4..fdbee1055e 100644 --- a/include/cutlass/tensor_view.h +++ b/include/cutlass/tensor_view.h @@ -176,6 +176,12 @@ class TensorView : public TensorRef { CUTLASS_HOST_DEVICE Index 
extent(int dim) const { return extent_.at(dim); } + /// Returns the number of logical elements + CUTLASS_HOST_DEVICE + LongIndex size() const { + return extent_.product(); + } + /// Determines whether a location is within a tensor CUTLASS_HOST_DEVICE bool contains(TensorCoord const& coord) const { diff --git a/include/cutlass/tfloat32.h b/include/cutlass/tfloat32.h index 64dc391497..2d28851299 100644 --- a/include/cutlass/tfloat32.h +++ b/include/cutlass/tfloat32.h @@ -82,7 +82,7 @@ struct alignas(4) tfloat32_t { /// Default constructor CUTLASS_HOST_DEVICE - tfloat32_t() { } + tfloat32_t() : storage(0) { } /// Floating-point conversion - round toward nearest even CUTLASS_HOST_DEVICE diff --git a/include/cutlass/matrix_traits.h b/include/cutlass/trace.h similarity index 79% rename from include/cutlass/matrix_traits.h rename to include/cutlass/trace.h index cf7002a42a..39ffa2968c 100644 --- a/include/cutlass/matrix_traits.h +++ b/include/cutlass/trace.h @@ -23,29 +23,31 @@ * **************************************************************************************************/ /*! \file - \brief Defines properties of matrices used to denote layout and operands to GEMM kernels. -*/ -#pragma once - -#include "cutlass/cutlass.h" - -namespace cutlass { + \brief Helpers for optionally tracing through code when debugging. -//////////////////////////////////////////////////////////////////////////////////////////////////// + This file is to be included after all other headers. +*/ -enum class MatrixLayout { - kColumnMajor, - kRowMajor -}; +#pragma once //////////////////////////////////////////////////////////////////////////////////////////////////// -/// Transformation applied to matrix operands -enum class MatrixTransform { - kNone, /// no operation - kTranspose /// transpose operation -}; +// Tracing options +#ifndef CUTLASS_DEBUG_TRACE_LEVEL +#define CUTLASS_DEBUG_TRACE_LEVEL 0 +#endif + +#if CUTLASS_DEBUG_TRACE_LEVEL +#include +#include "cutlass/core_io.h" +#if defined(__CUDA_ARCH__) +#define CUTLASS_TRACE_HOST(x) +#else +#define CUTLASS_TRACE_HOST(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; } +#endif +#else +#define CUTLASS_TRACE_HOST(x) +#endif //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass diff --git a/include/cutlass/transform/pitch_linear_thread_map.h b/include/cutlass/transform/pitch_linear_thread_map.h index 812dbd772b..de21ede4ea 100644 --- a/include/cutlass/transform/pitch_linear_thread_map.h +++ b/include/cutlass/transform/pitch_linear_thread_map.h @@ -44,13 +44,17 @@ namespace transform { //////////////////////////////////////////////////////////////////////////////// -/// Strip-mines a pitch-linear tile among a given number of threads, first along the contiguous -/// dimension then along the strided dimension. +/// Strip-mines a pitch-linear tile among a given number of threads, first along +/// the contiguous dimension then along the strided dimension. /// -/// The tile must be divisible by the thread count such that all threads may execute the same -/// number of iterations with the same delta to exhaustively cover the tile. +/// The tile must be divisible by the thread count such that all threads may +/// execute the same number of iterations with the same delta to exhaustively +/// cover the tile. /// /// This class satisfies the "RegularThreadMapping" concept. +/// +/// This ThreadMap is used by SIMT kernels and operand E of the sparse tensor +/// kernels. 
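// For illustration only (not part of this patch): with a hypothetical Shape of PitchLinearShape<64, 4>,
// Threads = 32, and one element per access, ShapeVec is <64, 4>. Since 32 < 64, the definition below
// yields Iterations = <64 / 32, 4> = <2, 4>, so each thread performs 8 accesses, and initial_offset()
// places thread t at contiguous offset (t % 64) = t and strided offset (t / 64) = 0.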
template < typename Shape_, int Threads, @@ -96,16 +100,17 @@ struct PitchLinearStripminedThreadMap { /// Number of iterations by each thread using Iterations = typename platform::conditional< - Threads >= Detail::ShapeVec::kContiguous, - layout::PitchLinearShape< - 1, - (Threads >= Detail::ShapeVec::kContiguous ? Detail::ShapeVec::kStrided / (kThreads / Detail::ShapeVec::kContiguous) : 0) - >, - layout::PitchLinearShape< - Detail::ShapeVec::kContiguous / kThreads, - Detail::ShapeVec::kStrided - > - >::type; + Threads >= Detail::ShapeVec::kContiguous, + layout::PitchLinearShape< + 1, + // Redo the comparison here to work around divide by zero compiler + // error. The compiler evaluates both path of platform::conditional. + (Threads >= Detail::ShapeVec::kContiguous + ? Detail::ShapeVec::kStrided / + (kThreads / Detail::ShapeVec::kContiguous) + : 0)>, + layout::PitchLinearShape>::type; /// Interval between accesses along each dimension of the tensor's logical coordinate space /// (in units of Elements) @@ -125,13 +130,13 @@ struct PitchLinearStripminedThreadMap { /// (in units of Elements) CUTLASS_HOST_DEVICE static TensorCoord initial_offset(int thread_id) { - return TensorCoord( (thread_id % Detail::ShapeVec::kContiguous) * kElementsPerAccess, thread_id / Detail::ShapeVec::kContiguous); } }; +/// This ThreadMap is used by GEMV template < typename Shape, int Threads, @@ -196,6 +201,8 @@ struct PitchLinearTilePolicyStripminedThreadStrided /// Policy defining a warp-raked arrangement in which a shape is partitioned into contiguous /// elements. +/// +/// This ThreadMap is used by tensor core kernels. template < typename Shape_, int Threads, @@ -241,6 +248,14 @@ struct PitchLinearWarpRakedThreadMap { Shape::kStrided >; + static_assert( + !(ShapeInAccesses::kContiguous % WarpThreadArrangement::kContiguous), + "ShapeInAccesses must be divisible by WarpThreadArrangement."); + + static_assert( + !(ShapeInAccesses::kStrided % WarpThreadArrangement::kStrided), + "ShapeInAccesses must be divisible by WarpThreadArrangement."); + // compute number of warp-level accesses total using WarpAccessIterations = layout::PitchLinearShape< ShapeInAccesses::kContiguous / WarpThreadArrangement::kContiguous, @@ -672,16 +687,17 @@ struct PitchLinear2DThreadTileStripminedThreadMap = Detail::ShapeVec::kContiguous, - layout::PitchLinearShape< - 1, - (Threads >= Detail::ShapeVec::kContiguous ? Detail::ShapeVec::kStrided / (kThreads / Detail::ShapeVec::kContiguous) : 0) - >, - layout::PitchLinearShape< - Detail::ShapeVec::kContiguous / kThreads, - Detail::ShapeVec::kStrided - > - >::type; + Threads >= Detail::ShapeVec::kContiguous, + layout::PitchLinearShape< + 1, + // Redo the comparison here to work around divide by zero compiler + // error. The compiler evaluates both path of platform::conditional. + (Threads >= Detail::ShapeVec::kContiguous + ? 
Detail::ShapeVec::kStrided / + (kThreads / Detail::ShapeVec::kContiguous) + : 0)>, + layout::PitchLinearShape>::type; /// Interval between accesses along each dimension of the tensor's logical coordinate space /// (in units of Elements) diff --git a/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h b/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h index c77a09ffbd..7e34b546be 100644 --- a/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h +++ b/include/cutlass/transform/threadblock/predicated_tile_access_iterator.h @@ -292,7 +292,7 @@ class PredicatedTileAccessIterator(access_ptr + access_offset); return reinterpret_cast(access_byte_ptr + byte_offset_); diff --git a/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h b/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h index c3f0b5249b..2dcd57d658 100644 --- a/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h +++ b/include/cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h @@ -49,7 +49,8 @@ namespace threadblock { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Regular tile iterator specialized for pitch-linear +/// Regular tile iterator specialized for pitch-linear. This one is used by 2-stage SIMT kernels +/// and sparse tensor core meta data. template < typename Shape_, typename Element_, @@ -139,7 +140,8 @@ class RegularTileIterator f64` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm80.cu) | | **TensorOp** | 80 | 11.0+ | `cf32 * cf32 + cf32 => cf32` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_cf32n_cf32t_cf32t_tensor_op_tf32_f32_sm80.cu) | | **TensorOp** | 80 | 11.0+ | `cf64 * cf64 + cf64 => cf64` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_sm80.cu), [Gaussian 3m](/test/unit/gemm/device/gemm_cf64n_cf64t_cf64t_tensor_op_f64_gaussian_sm80.cu) | +| **SpTensorOp** | 80 | 11.1+ | `f16 * f16 + f32 => {f16, f32}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu) | +| **SpTensorOp** | 80 | 11.1+ | `bf16 * bf16 + f32 => {bf16, f32}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu) | +| **SpTensorOp** | 80 | 11.1+ | `tf32 * tf32 + f32 => f32` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu) | +| **SpTensorOp** | 80 | 11.1+ | `s8 * s8 + s32 => {s8, s32}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu) | +| **SpTensorOp** | 80 | 11.1+ | `s4 * s4 + s32 => {s4, s32}` | {N,T} x {N,T} => {N,T} | [example](/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu) | + + ## Warp-level Matrix Multiply with Tensor Cores @@ -53,6 +60,11 @@ The following table summarizes supported warp level shapes for each TensorOp ins | **TensorOp** | 16-by-8-by-64 | 32x32x64, 32x64x64, 64x32x64, 64x64x64 | | **TensorOp** | 8-by-8-by-128 | 32x32x128, 32x64x128, 64x32x128, 64x64x128 | | **TensorOp** | 16-by-8-by-256 | 32x32x256, 32x64x256, 64x32x256, 64x64x256 | +| **SpTensorOp** | 16-by-8-by-16 | 64x64x16, 64x32x16, 32x64x16, 32x32x16 | +| **SpTensorOp** | 16-by-8-by-32 | 64x64x32, 64x32x32, 32x64x32, 32x32x32 | +| **SpTensorOp** | 16-by-8-by-64 | 64x64x64, 64x32x64, 32x64x64, 32x32x64 | +| 
**SpTensorOp** | 16-by-8-by-128 | 64x64x128, 64x32x128, 32x64x128, 32x32x128 | + TensorOp instructions depend on a permuted shared memory layout that can be efficiently loaded from. The following tables summarize the destination shared memory layout that @@ -154,6 +166,40 @@ from global memory with layout specified in the column "GMEM Layout." | **C** | `int32_t` | `RowMajor` | `RowMajor` | +**SpTensorOp 16-by-8-by-16.** + +|**Operand**|**Element** | **GMEM Layout** | **SMEM Layout** | +|-----------|--------------|-----------------|------------------------------------| +| **A** | `tfloat32_t` | `RowMajor` | `RowMajorTensorOpCrosswise<32, 32>` | +| **B** | `tfloat32_t` | `ColumnMajor` | `ColumnMajorTensorOpCrosswise<32, 32>`| +| **C** | `float` | `RowMajor` | `RowMajor` | + +**SpTensorOp 16-by-8-by-32.** + +|**Operand**|**Element** | **GMEM Layout** | **SMEM Layout** | +|-----------|--------------|-----------------|---------------------------------------| +| **A** | `half_t` | `RowMajor` | `RowMajorTensorOpCrosswise<16, 64>` | +| **B** | `half_t` | `ColumnMajor` | `ColumnMajorTensorOpCrosswise<16, 64>`| +| **C** | `float` | `RowMajor` | `RowMajor` | + +**SpTensorOp 16-by-8-by-64.** + +|**Operand**|**Element** | **GMEM Layout** | **SMEM Layout** | +|-----------|--------------|-----------------|---------------------------------------| +| **A** | `int8_t` | `RowMajor` | `RowMajorTensorOpCrosswise<8, 128>` | +| **B** | `int8_t` | `ColumnMajor` | `ColumnMajorTensorOpCrosswise<8, 128>`| +| **C** | `int32_t` | `RowMajor` | `RowMajor` | + +**SpTensorOp 16-by-8-by-128.** + +|**Operand**|**Element** | **GMEM Layout** | **SMEM Layout** | +|-----------|--------------|-----------------|------------------------------------| +| **A** | `int4b_t` | `RowMajor` | `RowMajorTensorOpCrosswise<4, 256>` | +| **B** | `int4b_t` | `ColumnMajor` | `ColumnMajorTensorOpCrosswise<4, 256>`| +| **C** | `int32_t` | `RowMajor` | `RowMajor` | + + + ## Warp-level Matrix Multiply with CUDA WMMA API The following table summarizes supported warp level shapes for each WmmaTensorOp instruction. diff --git a/media/docs/gemm_api.md b/media/docs/gemm_api.md index 759b1cd417..fec32a0451 100644 --- a/media/docs/gemm_api.md +++ b/media/docs/gemm_api.md @@ -503,6 +503,33 @@ struct Mma; } // namespace cutlass ``` +## Efficient Epilogue + +CUTLASS GEMM operators perform mma followed by epilogue operation similar +to cuBLAS. CUTLASS implements an efficient row-major epilogue. Thus, to achieve +column-major GEMM, operands A & B are transposed and swapped. + +To enable efficient row-major epilogue for both row-major and column-major output layout, +CUTLASS' device-level GEMM operators `cutlass::device::Gemm` and `cutlass::device::GemmUniversal` +provide two template definitions: +- (a) [General definition](/include/cutlass/gemm/device/gemm.h#L217) +- (b) [Specialized definition for column-major source/output](/include/cutlass/gemm/device/gemm.h#L545) + +Efficient row-major epilogue for: +- (i) GEMM operator on row-major source/output uses template (a). It runs row-major GEMM and +an efficient row-major epilogue. +- (ii) GEMM operator on column-major source/output uses template (b). It transposes and swaps +operands A and B to enable efficient epilogue. `A x B = C => Transpose(B) x Transpose(A) = Transpose(C)`. +For column-major source (C) matrix, Transpose(C) is row-major, and efficient epilogue works on +row-major. + +Note that cuBLAS typically expects a column-major source (C) and output matrix (D). 
Thus, +CUTLASS library only instantiates and generates GEMM operators with column-major layout. However, +CUTLASS by itself can run both row-major and column-major output layouts for all combinations +of input layouts. Thus, CUTLASS supports the following layout combinations for input and output layouts: + +- `{N,T} x {N,T} => {N,T}` - NN, NT, TN, TT GEMM for both row-major and column-major output + ## Instruction-level operations CUTLASS defines a template-based interface to Tensor Core operations to avoid resorting diff --git a/media/docs/profiler.md b/media/docs/profiler.md index 7d2356c558..dd1f62a7c9 100644 --- a/media/docs/profiler.md +++ b/media/docs/profiler.md @@ -15,7 +15,7 @@ $ make cutlass_profiler -j To limit compilation time, only one tile size (128x128) is instantiated for each data type, math instruction, and layout. To instantiate all sizes, set the following environment variable when running CMake from an empty `build/` directory. ```bash -$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" -DCUTLASS_LIBRARY_KERNELS=all -DCUTLASS_UNITY_BUILD_ENABLED=ON +$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" -DCUTLASS_LIBRARY_KERNELS=all -DCUTLASS_UNITY_BUILD_ENABLED=ON ... $ make cutlass_profiler -j ``` diff --git a/media/docs/quickstart.md b/media/docs/quickstart.md index 082b4c10b4..427fe13c66 100644 --- a/media/docs/quickstart.md +++ b/media/docs/quickstart.md @@ -23,6 +23,15 @@ $ mkdir build && cd build $ cmake .. -DCUTLASS_NVCC_ARCHS=80 # compiles for NVIDIA Ampere GPU architecture ``` +If your goal is strictly to build only the CUTLASS Profiler and to minimize compilation time, we suggest +executing the following CMake command in an empty `build/` directory. +```bash +$ cmake .. -DCUTLASS_NVCC_ARCHS=80 -DCUTLASS_ENABLE_TESTS=OFF -DCUTLASS_UNITY_BUILD_ENABLED=ON +``` + +This reduces overall compilation time by excluding unit tests and enabling the unity build. + + ## Build and run the CUTLASS Profiler From the `build/` directory created above, compile the the CUTLASS Profiler. @@ -403,7 +412,7 @@ $ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=sgemm Compling only the kernels desired reduces compilation time. To instantiate kernels of all tile sizes, data types, and alignment constraints, specify -`-DCUTLASS_LIBRARY_KERNELS=all` when running `cmake`. +`-DCUTLASS_LIBRARY_KERNELS=all` when running `cmake`. Several recipes are defined below for convenience. They may be combined as a comma-delimited list. @@ -416,8 +425,7 @@ $ cmake .. -DCUTLASS_NVCC_ARCHS=80 -DCUTLASS_LIBRARY_KERNELS=tensorop*gemm the "unity build" instantiates multiple kernel instances in each compilation unit, thereby reducing binary size and avoiding linker limitations on some platforms. ```bash -$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" -DCUTLASS_LIBRARY_KERNELS=all \ - -DCUTLASS_UNITY_BUILD_ENABLED=ON +$ cmake .. -DCUTLASS_NVCC_ARCHS="70;75;80" -DCUTLASS_LIBRARY_KERNELS=all -DCUTLASS_UNITY_BUILD_ENABLED=ON ``` **Example.** All GEMM kernels targeting Turing Tensor Cores.
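As an aside on the "Efficient Epilogue" note above, the transpose-and-swap identity can be checked with plain host code: running a row-major kernel on `Transpose(B) x Transpose(A)` produces exactly `C = A x B` laid out in column-major order. The sketch below is illustrative only and is not part of this patch; the `naive_gemm` helper and the 2x3x4 problem size are hypothetical, and no CUTLASS API is used.

```cpp
#include <cassert>
#include <vector>

// Naive row-major GEMM: C[m][n] = sum_k A[m][k] * B[k][n]  (M-by-N = M-by-K * K-by-N)
void naive_gemm(int M, int N, int K,
                std::vector<float> const &A,
                std::vector<float> const &B,
                std::vector<float> &C) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0;
      for (int k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[k * N + n];
      }
      C[m * N + n] = acc;
    }
  }
}

int main() {
  int const M = 2, N = 3, K = 4;

  std::vector<float> A(M * K), B(K * N);
  for (int i = 0; i < M * K; ++i) A[i] = float(i + 1);
  for (int i = 0; i < K * N; ++i) B[i] = float(2 * i - 3);

  // Reference: C = A x B with C stored row-major.
  std::vector<float> C(M * N);
  naive_gemm(M, N, K, A, B, C);

  // Swapped form: build B^T (N-by-K) and A^T (K-by-M) and run the same row-major kernel.
  // The row-major result D = B^T x A^T (N-by-M) is C^T, i.e. C stored column-major.
  std::vector<float> At(K * M), Bt(N * K), D(N * M);
  for (int m = 0; m < M; ++m)
    for (int k = 0; k < K; ++k) At[k * M + m] = A[m * K + k];
  for (int k = 0; k < K; ++k)
    for (int n = 0; n < N; ++n) Bt[n * K + k] = B[k * N + n];
  naive_gemm(N, M, K, Bt, At, D);

  // Element (m, n) of column-major C is element (n, m) of the row-major product D.
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      assert(D[n * M + m] == C[m * N + n]);

  return 0;
}
```

This is the mapping the specialized column-major template applies internally: operands are swapped and reinterpreted so that the efficient row-major epilogue writes what the caller observes as a column-major output.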
diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 610eee0112..52368a346a 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -103,6 +103,7 @@ set(SUBDIRS transform epilogue reduction + util ) if(TARGET nvidia::nvrtc AND TARGET nvidia::cuda_driver) diff --git a/test/unit/core/CMakeLists.txt b/test/unit/core/CMakeLists.txt index d72f42fb03..19282035f5 100644 --- a/test/unit/core/CMakeLists.txt +++ b/test/unit/core/CMakeLists.txt @@ -27,6 +27,8 @@ cutlass_test_unit_add_executable( bfloat16.cu tfloat32.cu complex.cu + quaternion.cu + matrix.cu predicate_vector.cu tensor_ref.cu tensor_view.cu diff --git a/test/unit/core/bfloat16.cu b/test/unit/core/bfloat16.cu index 9fa99ebb7f..d33ff2cc3c 100644 --- a/test/unit/core/bfloat16.cu +++ b/test/unit/core/bfloat16.cu @@ -142,6 +142,9 @@ TEST(bfloat16_t, host_conversion) { EXPECT_TRUE(static_cast(y) == f); } + // Try out default-ctor (zero initialization of primitive proxy type) + EXPECT_TRUE(cutlass::bfloat16_t() == 0.0_bf16); + // Try out user-defined literals EXPECT_TRUE(cutlass::bfloat16_t(7) == 7_bf16); EXPECT_TRUE(7 == static_cast(7_bf16)); diff --git a/test/unit/core/complex.cu b/test/unit/core/complex.cu index 9f70708d37..003762f719 100644 --- a/test/unit/core/complex.cu +++ b/test/unit/core/complex.cu @@ -23,15 +23,16 @@ * **************************************************************************************************/ /*! \file - \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types - and is safe to use in a union. + \brief CUTLASS host-device template for complex numbers supporting all CUTLASS numeric types. */ +// Standard Library's std::complex used for reference checking +#include + #include "../common/cutlass_unit_test.h" #include "cutlass/complex.h" #include "cutlass/numeric_conversion.h" -#include "cutlass/util/device_memory.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -82,4 +83,70 @@ TEST(complex, f16_to_f32_conversion) { dest.real() == 1.5f && dest.imag() == -1.25f); } +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { + + /// Thorough testing for basic complex math operators. Uses std::complex as a reference. 
+ template + struct ComplexOperators { + ComplexOperators() { + for (int ar = -N; ar <= N; ++ar) { + for (int ai = -N; ai <= N; ++ai) { + for (int br = -N; br <= N; ++br) { + for (int bi = -N; bi <= N; ++bi) { + + cutlass::complex Ae(T(ar) / T(M), T(ai) / T(M)); + cutlass::complex Be(T(br) / T(M), T(bi) / T(M)); + + std::complex Ar(T(ar) / T(M), T(ai) / T(M)); + std::complex Br(T(br) / T(M), T(bi) / T(M)); + + cutlass::complex add_e = Ae + Be; + cutlass::complex sub_e = Ae - Be; + cutlass::complex mul_e = Ae * Be; + + std::complex add_r = (Ar + Br); + std::complex sub_r = (Ar - Br); + std::complex mul_r = (Ar * Br); + + EXPECT_EQ(real(add_e), real(add_r)); + EXPECT_EQ(imag(add_e), imag(add_r)); + + EXPECT_EQ(real(sub_e), real(sub_r)); + EXPECT_EQ(imag(sub_e), imag(sub_r)); + + EXPECT_EQ(real(mul_e), real(mul_r)); + EXPECT_EQ(imag(mul_e), imag(mul_r)); + + if (!(br == 0 && bi == 0)) { + + cutlass::complex div_e = Ae / Be; + std::complex div_r = Ar / Br; + + T const kRange = T(0.001); + + EXPECT_NEAR(real(div_e), real(div_r), kRange); + EXPECT_NEAR(imag(div_e), imag(div_r), kRange); + } + } + } + } + } + } + }; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(complex, host_float) { + test::ComplexOperators test; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(complex, host_double) { + test::ComplexOperators test; +} + ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/core/half.cu b/test/unit/core/half.cu index be5e9b433d..dad1f97a79 100644 --- a/test/unit/core/half.cu +++ b/test/unit/core/half.cu @@ -50,6 +50,9 @@ TEST(half_t, host_conversion) { EXPECT_TRUE(static_cast(y) == f); } + // Try out default-ctor (zero initialization of primitive proxy type) + EXPECT_TRUE(cutlass::half_t() == 0.0_hf); + // Try out user-defined literals EXPECT_TRUE(cutlass::half_t(7) == 7_hf); EXPECT_TRUE(7 == static_cast(7_hf)); diff --git a/test/unit/core/matrix.cu b/test/unit/core/matrix.cu new file mode 100644 index 0000000000..f012fe9f87 --- /dev/null +++ b/test/unit/core/matrix.cu @@ -0,0 +1,198 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Unit tests for the small matrix class. +*/ + +#include + +#include "../common/cutlass_unit_test.h" + +#include "cutlass/matrix.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Matrix, elementwise_add) { + + using Matrix4x4 = cutlass::Matrix4x4; + + Matrix4x4 A = { + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16 + }; + + Matrix4x4 B = A.transpose(); + + Matrix4x4 C = A.add(B * 2.125f); + + bool passed = true; + + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + float got = C.at(i, j); + float expected = A.at(i, j) + A.at(j, i) * 2.125f; + if (got != expected) { + passed = false; + } + } + } + EXPECT_TRUE(passed); + if (!passed) { + std::cout << "A:\n" << A << "\n\nB:\n" << B << "\n\nC:\n" << C << std::endl; + } +} + +TEST(Matrix, elementwise_multiply) { + + using Matrix4x4 = cutlass::Matrix4x4; + + Matrix4x4 A = { + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16 + }; + + Matrix4x4 B = A.transpose(); + + Matrix4x4 C = A.multiply(B); + + bool passed = true; + + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + float got = C.at(i, j); + float expected = A.at(i, j) * A.at(j, i); + if (got != expected) { + passed = false; + } + } + } + EXPECT_TRUE(passed); + if (!passed) { + std::cout << "A:\n" << A << "\n\nB:\n" << B << "\n\nC:\n" << C << std::endl; + } +} + +TEST(Matrix, product_4x4_overloads) { + + using Matrix4x4 = cutlass::Matrix4x4; + + Matrix4x4 A = { + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16 + }; + + Matrix4x4 B = { + -1, -2, 0, 4, + 1, 2, 1, 1, + 3, 2, 1, 1, + 1, 0, 8, 2 + }; + + Matrix4x4 C = Matrix4x4::identity(); + + Matrix4x4 D = A * B + C; + + bool passed = true; + + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + float got = D.at(i, j); + float expected = (i == j ? 1.0f : 0); + for (int k = 0; k < 4; ++k) { + expected += A.at(i, k) * B.at(k, j); + } + if (got != expected) { + passed = false; + } + } + } + + EXPECT_TRUE(passed); + if (!passed) { + std::cout << "A:\n" << A << "\n\nB:\n" << B << "\n\nC:\n" << C << "\n\nD:\n" << D << std::endl; + } +} + + +TEST(Matrix, product_4x4) { + + using Matrix4x4 = cutlass::Matrix4x4; + + Matrix4x4 A = { + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16 + }; + + Matrix4x4 B = { + -1, -2, 0, 4, + 1, 2, 1, 1, + 3, 2, 1, 1, + 1, 0, 8, 2 + }; + + Matrix4x4 C = Matrix4x4::identity(); + + // Compute product with optional source accumulator + Matrix4x4 D = A.product(B, C); + + bool passed = true; + + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + float got = D.at(i, j); + float expected = (i == j ? 
1.0f : 0.0f); + for (int k = 0; k < 4; ++k) { + expected += A.at(i, k) * B.at(k, j); + } + if (got != expected) { + passed = false; + } + } + } + + EXPECT_TRUE(passed); + if (!passed) { + std::cout << "A:\n" << A << "\n\nB:\n" << B << "\n\nC:\n" << C << "\n\nD:\n" << D << std::endl; + } + + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + float c = (i == j ? 1.0f : 0.0f); + EXPECT_TRUE(A.row(i).dot(B.column(j)) + c == D.at(i, j)); + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/core/quaternion.cu b/test/unit/core/quaternion.cu new file mode 100644 index 0000000000..69ce928aec --- /dev/null +++ b/test/unit/core/quaternion.cu @@ -0,0 +1,162 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for the CUTLASS Quaternion template class. 
+*/ + +#include "../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/core_io.h" +#include "cutlass/quaternion.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/constants.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +static float const half_pi = cutlass::constants::half_pi(); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Quaternion, add_f32) { + + cutlass::Quaternion q0(1, 1, 1, 1); + cutlass::Quaternion q1(0, 0, 0, 2); + + cutlass::Quaternion q2 = q0 + q1; + + EXPECT_TRUE( + q2.x() == 1 && + q2.y() == 1 && + q2.z() == 1 && + q2.w() == 3 + ); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Quaternion, rotation) { + + cutlass::Matrix3x1 x(1.0f, 0.0f, 0.0f); + cutlass::Quaternion q = cutlass::Quaternion::rotation(0, 0, 1, half_pi) * 2.0f; + cutlass::Matrix3x1 v = q.rotate(x); + + float epsilon = 0.001f; + + EXPECT_TRUE( + std::abs(v.at(0)) < epsilon && + std::abs(v.at(1)) > (1 - epsilon) && + std::abs(v.at(2)) < epsilon + ); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Quaternion, rotation_inv) { + + cutlass::Matrix3x1 x(1.0f, 0.0f, 0.0f); + cutlass::Quaternion q = cutlass::Quaternion::rotation(0, 0, 1, half_pi) * 2.0f; + cutlass::Matrix3x1 v = q.rotate(x); + + float epsilon = 0.001f; + + EXPECT_TRUE( + std::abs(v.at(0)) < epsilon && + std::abs(-v.at(1)) > (1 - epsilon) && + std::abs(v.at(2)) < epsilon + ); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Quaternion, spinor_rotation) { + + cutlass::Matrix3x1 x(1.0f, 0.0f, 0.0f); + cutlass::Quaternion q = cutlass::Quaternion::rotation(0, 0, 1, half_pi); + cutlass::Matrix3x1 v = cutlass::spinor_rotation(q, x); + + float epsilon = 0.001f; + + EXPECT_TRUE( + std::abs(v.at(0)) < epsilon && + std::abs(v.at(1)) > (1 - epsilon) && + std::abs(v.at(2)) < epsilon + ); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Quaternion, spinor_rotation_inv) { + + cutlass::Matrix3x1 x(1.0f, 0.0f, 0.0f); + cutlass::Quaternion q = cutlass::Quaternion::rotation(0, 0, 1, half_pi); + cutlass::Matrix3x1 v = cutlass::spinor_rotation_inv(q, x); + + float epsilon = 0.001f; + + EXPECT_TRUE( + std::abs(v.at(0)) < epsilon && + std::abs(-v.at(1)) > (1 - epsilon) && + std::abs(v.at(2)) < epsilon + ); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Quaternion, as_rotation_matrix3x3) { + + cutlass::Matrix3x1 x(1.0f, 0.0f, 0.0f); + cutlass::Quaternion q = cutlass::Quaternion::rotation(0, 0, 1, half_pi); + cutlass::Matrix3x1 v = q.as_rotation_matrix_3x3().product(x); + + float epsilon = 0.001f; + + EXPECT_TRUE( + std::abs(v.at(0)) < epsilon && + std::abs(v.at(1)) > (1 - epsilon) && + std::abs(v.at(2)) < epsilon + ); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Quaternion, as_rotation_matrix4x4) { + + cutlass::Matrix4x1 x(1.0f, 0.0f, 0.0f, 1.0f); + cutlass::Quaternion q = cutlass::Quaternion::rotation(0, 0, 1, half_pi); + cutlass::Matrix4x1 v = q.as_rotation_matrix_4x4().product(x); + + float epsilon = 0.001f; + + EXPECT_TRUE( + std::abs(v.at(0)) < epsilon && + std::abs(v.at(1)) > (1 - epsilon) && + std::abs(v.at(2)) < epsilon && + 
std::abs(v.at(3)) > (1 - epsilon) + ); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/core/tensor_view.cu b/test/unit/core/tensor_view.cu index b35fc426b8..684ca5b0f2 100644 --- a/test/unit/core/tensor_view.cu +++ b/test/unit/core/tensor_view.cu @@ -114,9 +114,9 @@ TEST(TensorView, contiguous) { int32_t, cutlass::layout::ContiguousLayout> ContiguousTensorView; - cutlass::MatrixLayout layouts[] = { - cutlass::MatrixLayout::kColumnMajor, - cutlass::MatrixLayout::kRowMajor + cutlass::layout::Matrix layouts[] = { + cutlass::layout::Matrix::kColumnMajor, + cutlass::layout::Matrix::kRowMajor }; cutlass::Coord<2> bounds = cutlass::make_Coord(M, N); @@ -129,7 +129,7 @@ TEST(TensorView, contiguous) { int row_stride; int col_stride; - if (layouts[i] == cutlass::MatrixLayout::kColumnMajor) { + if (layouts[i] == cutlass::layout::Matrix::kColumnMajor) { row_stride = 1; col_stride = M; ldm = col_stride; @@ -156,7 +156,7 @@ TEST(TensorView, contiguous) { } std::cout << "---------\n"; - std::cout << (layouts[i] == cutlass::MatrixLayout::kColumnMajor ? + std::cout << (layouts[i] == cutlass::layout::Matrix::kColumnMajor ? "Column-major:" : "Row-major:") << "\n\n"; std::cout << "Logical view:\n"; @@ -165,7 +165,7 @@ TEST(TensorView, contiguous) { std::cout << "Linear memory:"; for (int idx = 0; idx < view.capacity(); ++idx) { - if (!(idx % (layouts[i] == cutlass::MatrixLayout::kColumnMajor ? M : N))) { + if (!(idx % (layouts[i] == cutlass::layout::Matrix::kColumnMajor ? M : N))) { std::cout << std::endl; } std::cout << std::setw(4) << view.at(idx) << " "; diff --git a/test/unit/core/tfloat32.cu b/test/unit/core/tfloat32.cu index 32155df7c4..9b54603fee 100644 --- a/test/unit/core/tfloat32.cu +++ b/test/unit/core/tfloat32.cu @@ -51,6 +51,9 @@ TEST(tfloat32_t, host_conversion) { EXPECT_TRUE(static_cast(y) == f); } + // Try out default-ctor (zero initialization of primitive proxy type) + EXPECT_TRUE(cutlass::tfloat32_t() == 0.0_tf32); + // Try out user-defined literals EXPECT_TRUE(cutlass::tfloat32_t(7) == 7_tf32); EXPECT_TRUE(7 == static_cast(7_tf32)); diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index f536b1136f..84247e0bdc 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -60,6 +60,20 @@ cutlass_test_unit_add_executable( gemm_tf32n_tf32n_f32t_tensor_op_f32_sm80.cu gemm_tf32t_tf32t_f32t_tensor_op_f32_sm80.cu + gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu + gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu + gemm_f16n_f16t_f16t_tensor_op_f16_sparse_sm80.cu + gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu + gemm_f16t_f16n_f16t_tensor_op_f16_sparse_sm80.cu + gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu + gemm_f16t_f16t_f32t_tensor_op_f32_sparse_sm80.cu + gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu + gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu + gemm_f32t_f32t_f32t_tensor_op_f32_sparse_sm80.cu + gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu + gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu + gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu + gemm_f16t_f16n_f16t_tensor_op_f16_slicedk_sm80.cu gemm_f16n_f16t_f16t_tensor_op_f16_slicedk_sm80.cu @@ -67,6 +81,7 @@ cutlass_test_unit_add_executable( simt_sgemm_tn_sm80.cu gemm_s8t_s8n_s32t_tensor_op_s32_sm80.cu + gemm_s8t_s8n_s32n_tensor_op_s32_sm80.cu gemm_s8t_s8n_s8n_tensor_op_s32_sm80.cu gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu gemm_s4t_s4n_s32n_tensor_op_s32_sm80.cu diff --git 
a/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu new file mode 100644 index 0000000000..1f4d3e2933 --- /dev/null +++ b/test/unit/gemm/device/gemm_f16n_f16n_f16t_tensor_op_f32_sparse_sm80.cu @@ -0,0 +1,266 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_sparse.h" + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x256x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x128x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x64x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + 
cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x256x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x128x64_32x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x64_64x32x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x64_32x32x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x128x128_64x64x128) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 256x64x128_64x64x128) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 128x64x128_64x32x128) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f16t_tensor_op_f32, 64x64x128_32x32x128) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if 
defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu new file mode 100644 index 0000000000..cef53a2dc9 --- /dev/null +++ b/test/unit/gemm/device/gemm_f16n_f16n_f32t_tensor_op_f32_sparse_sm80.cu @@ -0,0 +1,267 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_sparse.h" + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x256x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x128x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x64x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + 
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x256x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x128x64_32x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x64_64x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x64_32x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + 
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x128x128_64x64x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 256x64x128_64x64x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 128x64x128_64x32x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16n_f32t_tensor_op_f32, 64x64x128_32x32x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sparse_sm80.cu 
b/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sparse_sm80.cu new file mode 100644 index 0000000000..849b7582e6 --- /dev/null +++ b/test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f16_sparse_sm80.cu @@ -0,0 +1,265 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_sparse.h" + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x256x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x128x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x64x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, 
cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x256x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64> , + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x128x64_32x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x64_64x32x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x64_32x32x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + 
ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x128x128_64x64x128) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 256x64x128_64x64x128) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 128x64x128_64x32x128) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f16t_tensor_op_f16, 64x64x128_32x32x128) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} 
+//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu new file mode 100644 index 0000000000..8ae6464f27 --- /dev/null +++ b/test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu @@ -0,0 +1,266 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_sparse.h" + +#if (CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x128x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 
64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x256x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x128x64_32x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x64_64x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x64_32x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + 
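Each TEST name above encodes ThreadblockShape_WarpShape, and the trailing template integer is the pipeline stage count. A small compile-time check of the implied divisibility, assuming the device-level SparseGemm re-exports ThreadblockShape and WarpShape the way dense device::Gemm does:

// Compile-time sanity check for the tile pairings used in these tests; the
// ThreadblockShape / WarpShape typedefs are assumed to exist on the
// device-level SparseGemm type.
template <typename Gemm>
struct TileShapeCheck {
  static_assert(Gemm::ThreadblockShape::kM % Gemm::WarpShape::kM == 0,
                "warp tile M must divide threadblock tile M");
  static_assert(Gemm::ThreadblockShape::kN % Gemm::WarpShape::kN == 0,
                "warp tile N must divide threadblock tile N");
  static_assert(Gemm::ThreadblockShape::kK % Gemm::WarpShape::kK == 0,
                "warp tile K must divide threadblock tile K");
  static bool const value = true;
};

Referencing TileShapeCheck<Gemm>::value inside any of the TEST bodies would fail to compile if a pairing were mismatched.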
+TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x128x128_64x64x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 256x64x128_64x64x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 128x64x128_64x32x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16n_f16t_f32t_tensor_op_f32, 64x64x128_32x32x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::ColumnMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED + diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sparse_sm80.cu b/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sparse_sm80.cu new file mode 100644 
index 0000000000..ffba9c0dac --- /dev/null +++ b/test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f16_sparse_sm80.cu @@ -0,0 +1,267 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_sparse.h" + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x256x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x128x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x64x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, 
cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x256x64_64x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x128x64_32x64x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x64_64x32x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x64_32x32x64) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + 
ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x128x128_64x64x128) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 256x64x128_64x64x128) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 128x64x128_64x32x128) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f16t_tensor_op_f16, 64x64x128_32x32x128) { + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + 
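TestAllSparseGemm handles argument construction internally; when wrapping one of these Gemm types directly, a guarded initialization path with an explicit workspace query might look like the sketch below. The can_implement, get_workspace_size, and operator()(args, workspace) calls are assumptions carried over from the dense device-level GEMM interface, and the sketch additionally needs "cutlass/util/device_memory.h" beyond the headers already included in this file.

// Guarded launch sketch; the interface calls below are assumed to mirror the
// dense device-level GEMM and require "cutlass/util/device_memory.h".
template <typename Gemm>
cutlass::Status run_with_workspace_sketch(typename Gemm::Arguments const &args) {
  Gemm gemm_op;

  cutlass::Status status = gemm_op.can_implement(args);  // reject unsupported shapes or alignment early
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  size_t workspace_bytes = Gemm::get_workspace_size(args);  // nonzero only when split-K needs scratch
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_bytes);

  return gemm_op(args, workspace.get());  // initialize and run in one call
}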
+//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED diff --git a/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu new file mode 100644 index 0000000000..b62d99f78a --- /dev/null +++ b/test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu @@ -0,0 +1,265 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_sparse.h" + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x256x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x128x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x64x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + 
cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x256x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x128x64_32x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x64_64x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x64_32x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + 
EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x128x128_64x64x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 256x64x128_64x64x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 128x64x128_64x32x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16n_f32t_tensor_op_f32, 64x64x128_32x32x128) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED diff --git a/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sparse_sm80.cu 
b/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sparse_sm80.cu new file mode 100644 index 0000000000..0497e61945 --- /dev/null +++ b/test/unit/gemm/device/gemm_f16t_f16t_f32t_tensor_op_f32_sparse_sm80.cu @@ -0,0 +1,193 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_sparse.h" + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x256x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x128x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x128x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 256x64x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 
64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x256x64_64x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x128x64_32x64x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 128x64x64_64x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f16t_f16t_f32t_tensor_op_f32, 64x64x64_32x32x64) { + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::half_t, cutlass::layout::RowMajor, cutlass::half_t, + cutlass::layout::RowMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + 
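All of these instantiations size the epilogue for 128-bit stores per access, which is why the LinearCombination width is written as 128 divided by the output element's bit width: four elements for the float-output files, eight for the half_t-output ones. A two-line compile-time confirmation:

// 128-bit epilogue accesses: 4 floats or 8 half_t elements per store.
static_assert(128 / cutlass::sizeof_bits<float>::value == 4,
              "float output packs 4 elements per 128-bit epilogue access");
static_assert(128 / cutlass::sizeof_bits<cutlass::half_t>::value == 8,
              "half_t output packs 8 elements per 128-bit epilogue access");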
+//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu new file mode 100644 index 0000000000..869b59b51d --- /dev/null +++ b/test/unit/gemm/device/gemm_f32n_f32n_f32t_tensor_op_f32_sparse_sm80.cu @@ -0,0 +1,423 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_sparse.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed_sparse.h" + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 256x128x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x256x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x128x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 256x64x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + 
cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 64x256x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x64x32_64x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 64x128x32_32x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 64x64x32_32x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 10 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + 
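These float (TF32 tensor op) variants follow the same sparse operand convention as the f16 files: A is stored compressed along K, accompanied by a separate metadata tensor E. A small helper for the compressed A extent, with the 2:1 ratio written as a local constant because the 2:4 structured-sparsity format is an assumption stated here rather than something this file spells out:

// Compressed A extent sketch; the metadata tensor E has its own packing
// defined by gemm_sparse.h and is not reproduced here.
inline cutlass::MatrixCoord compressed_a_extent(cutlass::gemm::GemmCoord problem_size) {
  int const kSparse = 2;  // assumption: 2:4 sparsity keeps K / 2 columns of A
  return cutlass::MatrixCoord(problem_size.m(), problem_size.k() / kSparse);
}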
+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x128x64_64x64x64) {
+
+  using ElementOutput = float;
+  using ElementAccumulator = float;
+
+  using Gemm = cutlass::gemm::device::SparseGemm<
+    float,
+    cutlass::layout::ColumnMajor,
+    float,
+    cutlass::layout::ColumnMajor,
+    float,
+    cutlass::layout::RowMajor,
+    float,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm80,
+    cutlass::gemm::GemmShape<128, 128, 64>,
+    cutlass::gemm::GemmShape<64, 64, 64>,
+    cutlass::gemm::GemmShape<16, 8, 16>,
+    cutlass::epilogue::thread::LinearCombination<
+      ElementOutput,
+      128 / cutlass::sizeof_bits<ElementOutput>::value,
+      ElementAccumulator,
+      ElementAccumulator
+    >,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    3
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
+}
+
+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 256x64x64_64x64x64) {
+
+  using ElementOutput = float;
+  using ElementAccumulator = float;
+
+  using Gemm = cutlass::gemm::device::SparseGemm<
+    float,
+    cutlass::layout::ColumnMajor,
+    float,
+    cutlass::layout::ColumnMajor,
+    float,
+    cutlass::layout::RowMajor,
+    float,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm80,
+    cutlass::gemm::GemmShape<256, 64, 64>,
+    cutlass::gemm::GemmShape<64, 64, 64>,
+    cutlass::gemm::GemmShape<16, 8, 16>,
+    cutlass::epilogue::thread::LinearCombination<
+      ElementOutput,
+      128 / cutlass::sizeof_bits<ElementOutput>::value,
+      ElementAccumulator,
+      ElementAccumulator
+    >,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    3
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
+}
+
+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 128x64x64_64x32x64) {
+
+  using ElementOutput = float;
+  using ElementAccumulator = float;
+
+  using Gemm = cutlass::gemm::device::SparseGemm<
+    float,
+    cutlass::layout::ColumnMajor,
+    float,
+    cutlass::layout::ColumnMajor,
+    float,
+    cutlass::layout::RowMajor,
+    float,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm80,
+    cutlass::gemm::GemmShape<128, 64, 64>,
+    cutlass::gemm::GemmShape<64, 32, 64>,
+    cutlass::gemm::GemmShape<16, 8, 16>,
+    cutlass::epilogue::thread::LinearCombination<
+      ElementOutput,
+      128 / cutlass::sizeof_bits<ElementOutput>::value,
+      ElementAccumulator,
+      ElementAccumulator
+    >,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    4
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
+}
+
+TEST(SM80_Device_Sparse_Gemm_f32n_f32n_f32t_tensor_op_f32, 64x64x64_32x32x64) {
+
+  using ElementOutput = float;
+  using ElementAccumulator = float;
+
+  using Gemm = cutlass::gemm::device::SparseGemm<
+    float,
+    cutlass::layout::ColumnMajor,
+    float,
+    cutlass::layout::ColumnMajor,
+    float,
+    cutlass::layout::RowMajor,
+    float,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm80,
+    cutlass::gemm::GemmShape<64, 64, 64>,
+    cutlass::gemm::GemmShape<32, 32, 64>,
+    cutlass::gemm::GemmShape<16, 8, 16>,
+    cutlass::epilogue::thread::LinearCombination<
+      ElementOutput,
+      128 / cutlass::sizeof_bits<ElementOutput>::value,
+      ElementAccumulator,
+      ElementAccumulator
+    >,
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    6
+  >;
+
+  EXPECT_TRUE(test::gemm::device::TestAllSparseGemm<Gemm>());
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED)
diff --git a/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu
new file mode 100644
index 0000000000..fda4371705
--- /dev/null
+++ b/test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu
@@ -0,0 +1,423 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ *     * Redistributions of source code must retain the above copyright notice, this list of
+ *       conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above copyright notice, this list of
+ *       conditions and the following disclaimer in the documentation and/or other materials
+ *       provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ *       to endorse or promote products derived from this software without specific prior written
+ *       permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*!
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_sparse.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed_sparse.h" + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 256x128x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x256x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x128x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 256x64x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + 
cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 64x256x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x64x32_64x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 64x128x32_32x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 64x64x32_32x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 10 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 
128x128x64_64x64x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 256x64x64_64x64x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 128x64x64_64x32x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32n_f32t_f32t_tensor_op_f32, 64x64x64_32x32x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu new file mode 100644 index 0000000000..7c2b6c6e38 --- /dev/null +++ 
b/test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu @@ -0,0 +1,422 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_sparse.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed_sparse.h" + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x128x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x256x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x128x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x64x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + 
cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 64x256x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x64x32_64x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 64x128x32_32x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 64x64x32_32x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 10 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 
128x128x64_64x64x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 256x64x64_64x64x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 128x64x64_64x32x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32n_f32t_tensor_op_f32, 64x64x64_32x32x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_f32t_f32t_f32t_tensor_op_f32_sparse_sm80.cu b/test/unit/gemm/device/gemm_f32t_f32t_f32t_tensor_op_f32_sparse_sm80.cu new file mode 100644 index 0000000000..eec3ca4cdb --- /dev/null +++ 
b/test/unit/gemm/device/gemm_f32t_f32t_f32t_tensor_op_f32_sparse_sm80.cu @@ -0,0 +1,423 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_sparse.h" + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed_sparse.h" + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 256x128x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x256x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x128x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 256x64x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + 
cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 64x256x32_64x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x64x32_64x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 64x128x32_32x64x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 64x64x32_32x32x32) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 10 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x128x64_64x64x64) { + + using ElementOutput = float; 
+ using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 256x64x64_64x64x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 128x64x64_64x32x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_f32t_f32t_f32t_tensor_op_f32, 64x64x64_32x32x64) { + + using ElementOutput = float; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::SparseGemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 6 + >; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu b/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu new file mode 100644 index 0000000000..aaf618267e --- /dev/null +++ b/test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu @@ -0,0 +1,261 @@ 
+/************************************************************************************************** + Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + + Redistribution and use in source and binary forms, with or without modification, are permitted + provided that the following conditions are met: + * Redistributions of source code must retain the above copyright notice, this list of + conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, this list of + conditions and the following disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + to endorse or promote products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_sparse.h" + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x256x256_64x64x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 256>, + cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x128x256_64x64x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + 
cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 256>, + cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x128x256_64x64x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 256>, + cutlass::gemm::GemmShape<64, 64, 256>, + cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 256x64x256_64x64x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 256>, + cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x256x256_64x64x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 256>, + cutlass::gemm::GemmShape<64, 64, 256>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x128x256_32x64x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + 
cutlass::gemm::GemmShape<64, 128, 256>, + cutlass::gemm::GemmShape<32, 64, 256>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x64x256_64x32x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 256>, + cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x256_32x32x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 256>, + cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x128x512_64x64x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 512>, + cutlass::gemm::GemmShape<64, 64, 512>, + cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 128x64x512_64x32x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 512>, + cutlass::gemm::GemmShape<64, 32, 512>, cutlass::gemm::GemmShape<16, 8, 128>, + 
cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s4t_s4n_s32t_tensor_op_s32, 64x64x512_32x32x512) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + cutlass::int4b_t, cutlass::layout::RowMajor, cutlass::int4b_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 512>, + cutlass::gemm::GemmShape<32, 32, 512>, cutlass::gemm::GemmShape<16, 8, 128>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) + diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm80.cu b/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm80.cu new file mode 100644 index 0000000000..9e1076a833 --- /dev/null +++ b/test/unit/gemm/device/gemm_s8t_s8n_s32n_tensor_op_s32_sm80.cu @@ -0,0 +1,355 @@ +/************************************************************************************************** + Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + + Redistribution and use in source and binary forms, with or without modification, are permitted + provided that the following conditions are met: + * Redistributions of source code must retain the above copyright notice, this list of + conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, this list of + conditions and the following disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + to endorse or promote products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x256x128_64x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 256x128x128_64x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x128x128_64x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 256x64x128_64x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 128>, + 
cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x256x128_64x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x128x128_32x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 128>, + cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x64x128_64x32x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x64x128_32x32x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + 
EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x256x64_64x64x64) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 256x128x64_64x64x64) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x128x64_64x64x64) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 256x64x64_64x64x64) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x256x64_64x64x64) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, 
int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x128x64_32x64x64) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<32, 64, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 128x64x64_64x32x64) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +TEST(SM80_Device_Gemm_s8t_s8n_s32n_tensor_op_s32, 64x64x64_32x32x64) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::Gemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::ColumnMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>, cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllGemm()); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + diff --git a/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu b/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu new file mode 100644 index 0000000000..5b9b1d7d95 --- /dev/null +++ b/test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu @@ -0,0 +1,263 @@ +/************************************************************************************************** + Copyright (c) 2017-2020, 
NVIDIA CORPORATION. All rights reserved. + + Redistribution and use in source and binary forms, with or without modification, are permitted + provided that the following conditions are met: + * Redistributions of source code must retain the above copyright notice, this list of + conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright notice, this list of + conditions and the following disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + to endorse or promote products derived from this software without specific prior written + permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_sparse.h" + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x256x128_64x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x128x128_64x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 
128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x128x128_64x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 256x64x128_64x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<256, 64, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x256x128_64x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 256, 128>, + cutlass::gemm::GemmShape<64, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x128x128_32x64x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 128>, + cutlass::gemm::GemmShape<32, 64, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x64x128_64x32x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 128>, + cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x128_32x32x128) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 128>, + cutlass::gemm::GemmShape<32, 32, 128>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 10>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x128x256_64x64x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 256>, + cutlass::gemm::GemmShape<64, 64, 256>, + cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 128x64x256_64x32x256) { + using ElementOutput = int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 256>, + cutlass::gemm::GemmShape<64, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + +TEST(SM80_Device_Sparse_Gemm_s8t_s8n_s32t_tensor_op_s32, 64x64x256_32x32x256) { + using ElementOutput = 
int32_t; + using ElementAccumulator = int32_t; + using ElementCompute = int32_t; + + using Gemm = cutlass::gemm::device::SparseGemm< + int8_t, cutlass::layout::RowMajor, int8_t, + cutlass::layout::ColumnMajor, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 256>, + cutlass::gemm::GemmShape<32, 32, 256>, cutlass::gemm::GemmShape<16, 8, 64>, + cutlass::epilogue::thread::LinearCombinationClamp< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementCompute>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 6>; + + EXPECT_TRUE(test::gemm::device::TestAllSparseGemm()); +} + + +//////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) + + diff --git a/test/unit/gemm/device/testbed_sparse.h b/test/unit/gemm/device/testbed_sparse.h new file mode 100644 index 0000000000..d1d57b893c --- /dev/null +++ b/test/unit/gemm/device/testbed_sparse.h @@ -0,0 +1,440 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + + Testbed for sparse operations not to be released for CUDA 11.0 GA. Expected release is 11.1. 
+*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SparseTestbed { + + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + static int const kSparse = Gemm::GemmKernel::kSparse; + static int const kMetaSizeInBits = Gemm::GemmKernel::kMetaSizeInBits; + static int const kMaxID2 = Gemm::GemmKernel::kMaxID2; + static int const kElementsPerElementE = Gemm::GemmKernel::kElementsPerElementE; + + using ElementE = typename Gemm::GemmKernel::ElementE; + using LayoutE = cutlass::layout::RowMajor; + using ReorderedLayoutE = typename Gemm::GemmKernel::LayoutE; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_E; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_uncompressed; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + cutlass::HostTensor tensor_E; + cutlass::HostTensor tensor_E_reordered; + + // + // Methods + // + + SparseTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_E_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080) + : init_A(init_A_), + init_B(init_B_), + init_C(init_C_), + init_E(init_E_), + seed(seed_) {} + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // TODO: Implement the rest + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// 
Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + tensor_A.resize(cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); + tensor_A_uncompressed.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + tensor_E.resize(cutlass::make_Coord( + problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + tensor_E_reordered.resize(cutlass::make_Coord( + problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + if (init_E == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomSparseMeta( + tensor_E.host_view(), seed, kMetaSizeInBits); + } else if (init_E == cutlass::Distribution::Identity) { + uint32_t content = (kMaxID2 == 1) ? 0x44444444 : 0x4444; + cutlass::reference::host::TensorFill(tensor_E.host_view(), + (ElementE)(content)); + } else { + // TODO: Implement the rest + EXPECT_TRUE(false); + } + + cutlass::reorder_meta(tensor_E_reordered.host_ref(), tensor_E.host_ref(), + {problem_size.m(), problem_size.n(), + problem_size.k() / kSparse / kElementsPerElementE}); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + tensor_E_reordered.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nE =\n" 
<< tensor_E.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::uncompress(tensor_A_uncompressed.host_ref(), tensor_A.host_ref(), + tensor_E.host_ref(), problem_size.m(), problem_size.k()); + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, + ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, + alpha, + tensor_A_uncompressed.host_ref(), + tensor_B.host_ref(), + beta, + reference_D.host_ref(), + ElementAccumulator(0) + ); + + return compare_reference(problem_size, alpha, beta); + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + int split_k_slices = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D.device_ref(), + tensor_E_reordered.device_ref(), + {alpha, beta}, + split_k_slices + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllSparseGemm() { + bool passed = true; + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) + // because of the reordering of operand E + int const kAlignmentM = std::max(((sizeof(typename Gemm::ElementE) == 2) ? 
32 : 16), + kMinimumOperandElementSize); + + int const kAlignmentN = 128 / kMinimumOperandElementSize; + + int problem_size_m[] = {kAlignmentM, 512 - 3 * kAlignmentM}; + + int problem_size_n[] = {kAlignmentN, 512 - 2 * kAlignmentN}; + + int problem_size_k[] = {Gemm::ThreadblockShape::kK, + Gemm::ThreadblockShape::kK * (Gemm::kStages + 1)}; + + int split_k_slices[] = { + 1, 2, 3 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + SparseTestbed testbed; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int split_k : split_k_slices) { + + if (!Gemm::kSplitKSerial && split_k > 1) { + continue; + } + + if (split_k > 1 && k / Gemm::ThreadblockShape::kK < split_k) { + continue; + } + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + passed = testbed.run( + problem_size, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/test/unit/gemm/threadblock/CMakeLists.txt b/test/unit/gemm/threadblock/CMakeLists.txt index f208b9ef17..f4f074fe99 100644 --- a/test/unit/gemm/threadblock/CMakeLists.txt +++ b/test/unit/gemm/threadblock/CMakeLists.txt @@ -22,6 +22,9 @@ cutlass_test_unit_add_executable( cutlass_test_unit_gemm_threadblock + mma_multistage.cu + mma_multistage_sparse.cu + mma_pipelined_sm80.cu mma_pipelined_wmma_sm70.cu mma_pipelined_wmma_sm75.cu mma_singlestage_wmma_sm70.cu @@ -29,5 +32,7 @@ cutlass_test_unit_add_executable( mma_pipelined_sm70.cu mma_pipelined_sm75.cu mma_pipelined_simt.cu + mma_planar_complex_sm80.cu + ) diff --git a/test/unit/gemm/threadblock/mma_multistage.cu b/test/unit/gemm/threadblock/mma_multistage.cu new file mode 100644 index 0000000000..e4a030d6fa --- /dev/null +++ b/test/unit/gemm/threadblock/mma_multistage.cu @@ -0,0 +1,3827 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for threadblock-level GEMM +*/ + +#include "mma_multistage_testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + tensor_op_64x64x64_64x64x64_16x8x16_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + tensor_op_128x64x64_64x32x64_16x8x16_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + tensor_op_64x128x64_32x64x64_16x8x16_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 
128, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + tensor_op_128x128x64_64x64x64_16x8x16_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + multicta_256x256x384_128x128x64_64x64x64_16x8x16_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 384); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + multicta_512x256x384_256x128x64_64x64x64_16x8x16_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 384); + + using ThreadblockShape 
= cutlass::gemm::GemmShape<256, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + tensor_op_64x64x32_64x64x32_16x8x16_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + tensor_op_128x64x32_64x32x32_16x8x16_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + tensor_op_64x128x32_32x64x32_16x8x16_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; + using WarpShape = 
cutlass::gemm::GemmShape<32, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + tensor_op_128x128x32_64x64x32_16x8x16_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 384); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + multicta_256x256x384_128x128x32_64x64x32_16x8x16_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 384); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + multicta_512x256x768_256x128x32_64x64x32_16x8x16_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 768); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; 
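+  // Note (added comment): the 256x128 threadblock tile decomposes into 64x64 warp tiles, a 4x2 warp
+  // arrangement (8 warps per threadblock), matching the dim3 block(32, 8, 1) launch configuration below.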
+ using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + tensor_op_64x64x32_64x64x32_16x8x8_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + tensor_op_128x64x32_64x32x32_16x8x8_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + tensor_op_64x128x32_32x64x32_16x8x8_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; + using InstructionShape = 
cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + tensor_op_128x128x32_64x64x32_16x8x8_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + multicta_256x256x192_128x128x32_64x64x32_16x8x8_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 192); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + multicta_512x256x384_256x128x32_64x64x32_16x8x8_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 192); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = 
cutlass::gemm::GemmShape<16, 8, 8>;
+
+  float alpha = 1.f;
+  float beta = 0.0f;
+  int const Stages = 3;
+
+  // Define the MmaCore components
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+      ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
+      ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp,
+      Stages>;
+
+  dim3 grid(2, 2);
+  dim3 block(32, 8, 1);
+
+  test::gemm::threadblock::Testbed<MmaCore>(problem_size.m(), problem_size.n(),
+                                            problem_size.k(), alpha, beta)
+      .run(grid, block);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM80_gemm_threadblock_congruous,
+     tensor_op_64x64x16_64x64x16_16x8x8_4stage) {
+  using ElementA = cutlass::tfloat32_t;
+  using LayoutA = cutlass::layout::ColumnMajor;
+  using ElementB = cutlass::tfloat32_t;
+  using LayoutB = cutlass::layout::RowMajor;
+  using ElementC = float;
+  using LayoutC = cutlass::layout::ColumnMajor;
+
+  cutlass::gemm::GemmCoord problem_size(64, 64, 128);
+
+  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>;
+  using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
+
+  float alpha = 1.f;
+  float beta = 0.0f;
+  int const Stages = 4;
+
+  // Define the MmaCore components
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+      ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
+      ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp,
+      Stages>;
+
+  dim3 grid(1, 1);
+  dim3 block(32, 1, 1);
+
+  test::gemm::threadblock::Testbed<MmaCore>(problem_size.m(), problem_size.n(),
+                                            problem_size.k(), alpha, beta)
+      .run(grid, block);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM80_gemm_threadblock_congruous,
+     tensor_op_128x64x16_64x32x16_16x8x8_4stage) {
+  using ElementA = cutlass::tfloat32_t;
+  using LayoutA = cutlass::layout::ColumnMajor;
+  using ElementB = cutlass::tfloat32_t;
+  using LayoutB = cutlass::layout::RowMajor;
+  using ElementC = float;
+  using LayoutC = cutlass::layout::ColumnMajor;
+
+  cutlass::gemm::GemmCoord problem_size(128, 64, 128);
+
+  using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 16>;
+  using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
+
+  float alpha = 1.f;
+  float beta = 0.0f;
+  int const Stages = 4;
+
+  // Define the MmaCore components
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+      ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
+      ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp,
+      Stages>;
+
+  dim3 grid(1, 1);
+  dim3 block(32, 4, 1);
+
+  test::gemm::threadblock::Testbed<MmaCore>(problem_size.m(), problem_size.n(),
+                                            problem_size.k(), alpha, beta)
+      .run(grid, block);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM80_gemm_threadblock_congruous,
+     tensor_op_64x128x16_32x64x16_16x8x8_4stage) {
+  using ElementA = cutlass::tfloat32_t;
+  using LayoutA = cutlass::layout::ColumnMajor;
+  using ElementB = cutlass::tfloat32_t;
+  using LayoutB = cutlass::layout::RowMajor;
+  using ElementC = float;
+  using LayoutC = cutlass::layout::ColumnMajor;
+
+  cutlass::gemm::GemmCoord problem_size(64, 128, 128);
+
+  using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 16>;
+  using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
+
+  float alpha = 1.f;
+  float beta = 0.0f;
+  int const Stages = 4;
+
+  // Define the MmaCore components
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+      ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
+      ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp,
+      Stages>;
+
+  dim3 grid(1, 1);
+  dim3 block(32, 4, 1);
+
+  test::gemm::threadblock::Testbed<MmaCore>(problem_size.m(), problem_size.n(),
+                                            problem_size.k(), alpha, beta)
+      .run(grid, block);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM80_gemm_threadblock_congruous,
+     tensor_op_128x128x16_64x64x16_16x8x8_4stage) {
+  using ElementA = cutlass::tfloat32_t;
+  using LayoutA = cutlass::layout::ColumnMajor;
+  using ElementB = cutlass::tfloat32_t;
+  using LayoutB = cutlass::layout::RowMajor;
+  using ElementC = float;
+  using LayoutC = cutlass::layout::ColumnMajor;
+
+  cutlass::gemm::GemmCoord problem_size(128, 128, 128);
+
+  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>;
+  using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
+
+  float alpha = 1.f;
+  float beta = 0.0f;
+  int const Stages = 4;
+
+  // Define the MmaCore components
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+      ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
+      ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp,
+      Stages>;
+
+  dim3 grid(1, 1);
+  dim3 block(32, 4, 1);
+
+  test::gemm::threadblock::Testbed<MmaCore>(problem_size.m(), problem_size.n(),
+                                            problem_size.k(), alpha, beta)
+      .run(grid, block);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM80_gemm_threadblock_congruous,
+     multicta_256x256x192_128x128x16_64x64x16_16x8x8_4stage) {
+  using ElementA = cutlass::tfloat32_t;
+  using LayoutA = cutlass::layout::ColumnMajor;
+  using ElementB = cutlass::tfloat32_t;
+  using LayoutB = cutlass::layout::RowMajor;
+  using ElementC = float;
+  using LayoutC = cutlass::layout::ColumnMajor;
+
+  cutlass::gemm::GemmCoord problem_size(256, 256, 192);
+
+  using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>;
+  using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
+
+  float alpha = 1.f;
+  float beta = 0.0f;
+  int const Stages = 4;
+
+  // Define the MmaCore components
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+      ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
+      ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp,
+      Stages>;
+
+  dim3 grid(2, 2);
+  dim3 block(32, 4, 1);
+
+  test::gemm::threadblock::Testbed<MmaCore>(problem_size.m(), problem_size.n(),
+                                            problem_size.k(), alpha, beta)
+      .run(grid, block);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM80_gemm_threadblock_congruous,
+     multicta_512x256x384_256x128x16_64x64x16_16x8x8_4stage) {
+  using ElementA = cutlass::tfloat32_t;
+  using LayoutA = cutlass::layout::ColumnMajor;
+  using ElementB = cutlass::tfloat32_t;
+  using LayoutB = cutlass::layout::RowMajor;
+  using ElementC = float;
+  using LayoutC = cutlass::layout::ColumnMajor;
+
+  cutlass::gemm::GemmCoord problem_size(512, 256, 384);
+
+  using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 16>;
+  using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
+
+  float alpha = 1.f;
+  float beta = 0.0f;
+  int const Stages = 4;
+
+  // Define the MmaCore components
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+      ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
+      ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp,
+      Stages>;
+
+  dim3 grid(2, 2);
+  dim3 block(32, 8, 1);
+
+  test::gemm::threadblock::Testbed<MmaCore>(problem_size.m(), problem_size.n(),
+                                            problem_size.k(), alpha, beta)
+      .run(grid, block);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM80_gemm_threadblock_crosswise,
+     tensor_op_64x64x64_64x64x64_16x8x16_3stage) {
+  using ElementA = cutlass::half_t;
+  using LayoutA = cutlass::layout::RowMajor;
+  using ElementB = cutlass::half_t;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using ElementC = float;
+  using LayoutC = cutlass::layout::ColumnMajor;
+
+  cutlass::gemm::GemmCoord problem_size(64, 64, 256);
+
+  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>;
+  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
+
+  float alpha = 1.f;
+  float beta = 0.0f;
+  int const Stages = 3;
+
+  // Define the MmaCore components
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+      ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
+      ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp,
+      Stages>;
+
+  dim3 grid(1, 1);
+  dim3 block(32, 1, 1);
+
+  test::gemm::threadblock::Testbed<MmaCore>(problem_size.m(), problem_size.n(),
+                                            problem_size.k(), alpha, beta)
+      .run(grid, block);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM80_gemm_threadblock_crosswise,
+     tensor_op_64x64x64_32x32x64_16x8x16_3stage) {
+  using ElementA = cutlass::half_t;
+  using LayoutA = cutlass::layout::RowMajor;
+  using ElementB = cutlass::half_t;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using ElementC = float;
+  using LayoutC = cutlass::layout::ColumnMajor;
+
+  cutlass::gemm::GemmCoord problem_size(64, 64, 256);
+
+  using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>;
+  using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
+
+  float alpha = 1.f;
+  float beta = 0.0f;
+  int const Stages = 3;
+
+  // Define the MmaCore components
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+      ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
+      ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp,
+      Stages>;
+
+  dim3 grid(1, 1);
+  dim3 block(32, 4, 1);
+
+  test::gemm::threadblock::Testbed<MmaCore>(problem_size.m(), problem_size.n(),
+                                            problem_size.k(), alpha, beta)
+      .run(grid, block);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+
+TEST(SM80_gemm_threadblock_crosswise,
+     tensor_op_128x64x64_64x32x64_16x8x16_3stage) {
+  using ElementA = cutlass::half_t;
+  using LayoutA = cutlass::layout::RowMajor;
+  using ElementB = cutlass::half_t;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using ElementC = float;
+  using LayoutC = cutlass::layout::ColumnMajor;
+
+  cutlass::gemm::GemmCoord problem_size(128, 64, 256);
+
+  using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>;
+  using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
+
+  float alpha = 1.f;
+  float beta = 0.0f;
+  int const Stages = 3;
+
+  // Define the MmaCore components
+  using
MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x128x64_32x64x64_16x8x16_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x128x64_64x64x64_16x8x16_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 384); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_256x256x384_128x128x64_64x64x64_16x8x16_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 384); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename 
cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_512x256x768_256x128x64_64x64x64_16x8x16_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 768); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x64x32_64x64x32_16x8x16_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x64x32_32x32x32_16x8x16_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + 
ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x64x32_64x32x32_16x8x16_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x128x32_32x64x32_16x8x16_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x128x32_64x64x32_16x8x16_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 384); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, 
+ ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_256x256x384_128x128x32_64x64x32_16x8x16_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 384); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_512x256x768_256x128x32_64x64x32_16x8x16_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 768); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x64x32_64x64x32_16x8x8_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, 
LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x64x32_32x32x32_16x8x8_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x64x32_64x32x32_16x8x8_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x128x32_32x64x32_16x8x8_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + 
+ dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x128x32_64x64x32_16x8x8_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_256x256x192_128x128x32_64x64x32_16x8x8_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 192); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_512x256x192_256x128x32_64x64x32_16x8x8_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 192); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); 
+ dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x64x16_64x64x16_16x8x8_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x64x16_32x32x16_16x8x8_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x64x16_64x32x16_16x8x8_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + 
test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x128x16_32x64x16_16x8x8_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 16>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x128x16_64x64x16_16x8x8_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_256x256x192_128x128x16_64x64x16_16x8x8_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 192); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + 
test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_512x256x192_256x128x16_64x64x16_16x8x8_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 192); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x64x128_64x64x128_16x8x32_3stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x64x128_32x32x128_16x8x32_3stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + 
problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x64x128_64x32x128_16x8x32_3stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x128x128_32x64x128_16x8x32_3stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x128x128_64x64x128_16x8x32_3stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + 
+//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_256x256x768_128x128x128_64x64x128_16x8x32_3stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 768); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_512x256x768_256x128x128_64x64x128_16x8x32_3stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 768); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x64x64_64x64x64_16x8x32_4stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + 
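+// Note: every test in this file follows the same pattern, illustrated below as
+// a sketch (the aliases are the ones defined inside each TEST body; the element
+// types, layouts, tile shapes, stage count, and grid/block dimensions vary per
+// test, and the explicit Testbed template parameter is the MmaCore defined in
+// that test): the threadblock-scoped multistage mainloop is configured through
+// DefaultMmaCore and then exercised by the Testbed harness.
+//
+//   using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+//       ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA,
+//       ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp,
+//       Stages>;
+//
+//   test::gemm::threadblock::Testbed<MmaCore>(problem_size.m(), problem_size.n(),
+//                                             problem_size.k(), alpha, beta)
+//       .run(grid, block);
+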
+TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x64x64_32x32x64_16x8x32_4stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x64x64_64x32x64_16x8x32_4stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x128x64_32x64x64_16x8x32_4stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x128x64_64x64x64_16x8x32_4stage) { + using ElementA = int8_t; + using LayoutA 
= cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_256x256x768_128x128x64_64x64x64_16x8x32_4stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 768); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_512x256x768_256x128x64_64x64x64_16x8x32_4stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 768); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x64x256_64x64x256_16x8x64_3stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = 
cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 256>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x64x256_32x32x256_16x8x64_3stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 256>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x64x256_64x32x256_16x8x64_3stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 256>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x256x256_32x64x256_16x8x64_3stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + 
using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 256>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x256x256_64x64x256_16x8x64_3stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 256>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_256x256x1536_128x256x256_64x64x256_16x8x64_3stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 1536); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 256>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_512x256x1536_256x256x256_64x64x256_16x8x64_3stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = 
cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 1536); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 256>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x64x128_64x64x128_16x8x64_4stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x64x128_32x32x128_16x8x64_4stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x64x128_64x32x128_16x8x64_4stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + 
cutlass::gemm::GemmCoord problem_size(128, 64, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x256x128_32x64x128_16x8x64_4stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x256x128_64x64x128_16x8x64_4stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_256x256x1536_128x256x128_64x64x128_16x8x64_4stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord 
problem_size(256, 256, 1536); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_512x256x1536_256x256x128_64x64x128_16x8x64_4stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 1536); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x64x1024_64x64x1024_16x8x256_3stage) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 4096); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 1024>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 1024>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x64x1024_32x32x1024_16x8x256_3stage) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 4096); 
+ + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 1024>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 1024>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x64x1024_64x32x1024_16x8x256_3stage) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 4096); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 1024>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 1024>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x1024x1024_32x64x1024_16x8x256_3stage) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 4096); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 1024>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 1024>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x1024x1024_64x64x1024_16x8x256_3stage) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 4096); + + using ThreadblockShape 
= cutlass::gemm::GemmShape<128, 128, 1024>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 1024>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_256x256x6144_128x1024x1024_64x64x1024_16x8x256_3stage) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 6144); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 1024>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 1024>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_512x256x6144_256x1024x1024_64x64x1024_16x8x256_3stage) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 6144); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 1024>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 1024>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x64x512_64x64x512_16x8x256_4stage) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 4096); + + using ThreadblockShape 
= cutlass::gemm::GemmShape<64, 64, 512>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x64x512_32x32x512_16x8x256_4stage) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 4096); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 512>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x64x512_64x32x512_16x8x256_4stage) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 4096); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 512>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_64x128x512_32x64x512_16x8x256_4stage) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 4096); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 512>; + 
using WarpShape = cutlass::gemm::GemmShape<32, 64, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + tensor_op_128x128x512_64x64x512_16x8x256_4stage) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 4096); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 512>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_256x256x6144_128x128x512_64x64x512_16x8x256_4stage) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 6144); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 512>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_512x256x6144_256x128x512_64x64x512_16x8x256_4stage) { + using ElementA = cutlass::uint1b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::uint1b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 6144); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 512>; + using 
WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + tensor_op_64x64x16_32x64x16_8x8x4_3stage) { + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 16); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 2, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k()) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + tensor_op_128x128x16_32x64x16_8x8x4_3stage) { + using ElementA = double; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = double; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 64); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k()) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_interleaved, + tensor_op_64x128x64_32x64x64_16x8x32_3stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; + using ElementB = int8_t; + using LayoutB = cutlass::layout::RowMajorInterleaved<32>; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore 
components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_interleaved, + tensor_op_128x128x64_64x64x64_16x8x32_3stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; + using ElementB = int8_t; + using LayoutB = cutlass::layout::RowMajorInterleaved<32>; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_interleaved, + multicta_256x256x384_128x128x64_64x64x64_16x8x32_3stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; + using ElementB = int8_t; + using LayoutB = cutlass::layout::RowMajorInterleaved<32>; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 384); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_interleaved, + multicta_512x256x384_256x128x64_64x64x64_16x8x32_3stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::ColumnMajorInterleaved<32>; + using ElementB = int8_t; + using LayoutB = cutlass::layout::RowMajorInterleaved<32>; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 384); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore 
components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_interleaved, + tensor_op_64x128x128_32x64x128_16x8x64_3stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::RowMajorInterleaved<64>; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_interleaved, + tensor_op_128x128x128_64x64x128_16x8x64_3stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::RowMajorInterleaved<64>; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_interleaved, + multicta_256x256x768_128x128x128_64x64x128_16x8x64_3stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::RowMajorInterleaved<64>; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 768); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 
0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_interleaved, + multicta_512x256x1536_256x128x128_64x64x128_16x8x64_3stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::ColumnMajorInterleaved<64>; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::RowMajorInterleaved<64>; + using ElementC = int; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 1536); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise_f64, + tensor_op_32x32x16_16x16x16_8x8x4_4stage) { + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + + cutlass::gemm::GemmCoord problem_size(32, 32, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 16>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k()) + .run(grid, block); +} + +TEST(SM80_gemm_threadblock_crosswise_f64, + tensor_op_64x64x16_32x32x16_8x8x4_4stage) { + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, 
LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k()) + .run(grid, block); +} + +TEST(SM80_gemm_threadblock_crosswise_f64, + tensor_op_64x128x16_32x64x16_8x8x4_4stage) { + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 16>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k()) + .run(grid, block); +} + +TEST(SM80_gemm_threadblock_crosswise_f64, + tensor_op_128x64x16_64x32x16_8x8x4_4stage) { + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k()) + .run(grid, block); +} + +TEST(SM80_gemm_threadblock_crosswise_f64, + tensor_op_128x128x16_32x64x16_8x8x4_3stage) { + using ElementA = double; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = double; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = double; + using LayoutC = cutlass::layout::RowMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 128); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k()) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// +#endif diff --git a/test/unit/gemm/threadblock/mma_multistage_sparse.cu b/test/unit/gemm/threadblock/mma_multistage_sparse.cu new file mode 100644 index 0000000000..13eb180e05 --- /dev/null +++ 
b/test/unit/gemm/threadblock/mma_multistage_sparse.cu @@ -0,0 +1,2697 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Unit tests for threadblock-level GEMM +*/ + +#include "mma_multistage_sparse_testbed.h" + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_64x64x64_64x64x64_16x8x32_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_64x64x64_32x32x64_16x8x32_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_128x64x64_64x32x64_16x8x32_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 
4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_64x128x64_32x64x64_16x8x32_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_128x128x64_64x64x64_16x8x32_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + multicta_256x256x768_128x128x64_64x64x64_16x8x32_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 768); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 
block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + multicta_512x256x768_256x128x64_64x64x64_16x8x32_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 768); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x64x64_64x64x64_16x8x32_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x64x64_32x32x64_16x8x32_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 
block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_128x64x64_64x32x64_16x8x32_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x128x64_32x64x64_16x8x32_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_128x128x64_64x64x64_16x8x32_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 
block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + multicta_256x256x768_128x128x64_64x64x64_16x8x32_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 768); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + multicta_512x256x768_256x128x64_64x64x64_16x8x32_4stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 768); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_64x64x128_64x64x128_16x8x32_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + 
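+  // A single warp covers the 64x64x128 threadblock tile here (WarpShape ==
+  // ThreadblockShape), so the testbed below launches one warp per CTA,
+  // i.e. block(32, 1, 1).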
dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_128x64x128_64x32x128_16x8x32_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_64x128x128_32x64x128_16x8x32_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_128x128x128_64x32x128_16x8x32_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + 
Stages>; + + dim3 grid(1, 1); + dim3 block(32, 8, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + multicta_256x256x768_128x128x128_64x32x128_16x8x32_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 768); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x64x128_64x64x128_16x8x32_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x64x128_32x32x128_16x8x32_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, 
cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_128x64x128_64x32x128_16x8x32_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x128x128_32x64x128_16x8x32_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_128x128x128_64x32x128_16x8x32_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 512); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, 
LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 8, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + multicta_256x256x768_128x128x128_64x32x128_16x8x32_3stage) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 768); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_64x64x32_64x64x32_16x8x16_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_64x64x32_32x32x32_16x8x16_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, 
LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_128x64x32_64x32x32_16x8x16_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_64x128x32_32x64x32_16x8x16_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_128x128x32_64x64x32_16x8x16_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, 
LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + multicta_256x256x384_128x128x32_64x64x32_16x8x16_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 384); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + multicta_512x256x384_256x128x32_64x64x32_16x8x16_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 384); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x64x32_64x64x32_16x8x16_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, 
WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x64x32_32x32x32_16x8x16_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_128x64x32_64x32x32_16x8x16_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x128x32_32x64x32_16x8x16_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + 
ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_128x128x32_64x64x32_16x8x16_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + multicta_256x256x384_128x128x32_64x64x32_16x8x16_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 384); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + multicta_512x256x384_256x128x32_64x64x32_16x8x16_4stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 384); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename 
cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::SparseTestbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_64x64x64_64x64x64_16x8x16_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_128x64x64_64x32x64_16x8x16_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_64x128x64_32x64x64_16x8x16_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore 
= typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + tensor_op_128x128x64_64x32x64_16x8x16_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 8, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_congruous, + multicta_256x256x384_128x128x64_64x32x64_16x8x16_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 384); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x64x64_64x64x64_16x8x16_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the 
MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x64x64_32x32x64_16x8x16_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_128x64x64_64x32x64_16x8x16_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x128x64_32x64x64_16x8x16_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + 
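+  // The 64x128 threadblock tile is covered by 2x2 = 4 warps of shape 32x64,
+  // which matches the block(32, 4, 1) launch configuration used below.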
// Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_128x128x64_64x32x64_16x8x16_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 256); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 8, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + multicta_256x256x384_128x128x64_64x32x64_16x8x16_3stage) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 384); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x64x128_64x64x128_16x8x64_4stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const 
Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x64x128_32x32x128_16x8x64_4stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 128>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_128x64x128_64x32x128_16x8x64_4stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x128x128_32x64x128_16x8x64_4stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore 
components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_128x128x128_64x64x128_16x8x64_4stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + multicta_256x256x1536_128x128x128_64x64x128_16x8x64_4stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 1536); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + multicta_512x256x1536_256x128x128_64x64x128_16x8x64_4stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 1536); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 128>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore 
components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x64x256_64x64x256_16x8x64_3stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 256>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x64x256_32x32x256_16x8x64_3stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 256>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_128x64x256_64x32x256_16x8x64_3stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 256>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename 
cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x128x256_32x64x256_16x8x64_3stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 256>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_128x128x256_64x32x256_16x8x64_3stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 1024); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 256>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 8, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + multicta_256x256x1536_128x128x256_64x32x256_16x8x64_3stage) { + using ElementA = int8_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = int8_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 1536); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 256>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename 
cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x64x256_64x64x256_16x8x128_4stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 2048); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 256>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x64x256_32x32x256_16x8x128_4stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 2048); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 256>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_128x64x256_64x32x256_16x8x128_4stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 2048); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 256>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using 
MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x128x256_32x64x256_16x8x128_4stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 2048); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 256>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_128x128x256_64x64x256_16x8x128_4stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 2048); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 256>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + multicta_256x256x3072_128x128x256_64x64x256_16x8x128_4stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 3072); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 256>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + 
+ // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + multicta_512x256x3072_256x128x256_64x64x256_16x8x128_4stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 3072); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 256>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 4; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x64x512_64x64x512_16x8x128_3stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 2048); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 512>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x64x512_32x32x512_16x8x128_3stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 2048); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 512>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + + float alpha = 1.f; + float beta 
= 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_128x64x512_64x32x512_16x8x128_3stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 2048); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 512>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_64x128x512_32x64x512_16x8x128_3stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 2048); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 512>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + tensor_op_128x128x512_64x32x512_16x8x128_3stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 2048); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 512>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + + 
float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(1, 1); + dim3 block(32, 8, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_sparse_gemm_threadblock_crosswise, + multicta_256x256x3072_128x128x512_64x32x512_16x8x128_3stage) { + using ElementA = cutlass::int4b_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::int4b_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = int32_t; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 3072); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 512>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + + float alpha = 1.f; + float beta = 0.0f; + int const Stages = 3; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultSparseMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp, + Stages>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::SparseTestbed( + problem_size.m(), problem_size.n(), problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) + diff --git a/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h b/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h new file mode 100644 index 0000000000..7036e26d97 --- /dev/null +++ b/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h @@ -0,0 +1,398 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/core_io.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" + +namespace test { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template <typename Mma> +__global__ void kernel_multistage_mma_sparse(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC *ptr_C, int ldc, + typename Mma::IteratorE::Params params_E, + typename Mma::IteratorE::TensorRef ref_E) { + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + // Dynamic shared memory base pointer + extern __shared__ int GemmSharedStorageBase[]; + + // Declare pointer to dynamic shared memory. 
+ typename Mma::SharedStorage *shared_storage = + reinterpret_cast<typename Mma::SharedStorage *>(GemmSharedStorageBase); + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k() / Mma::kSparse}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + cutlass::MatrixCoord tb_offset_E{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k() / Mma::kSparse}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k() / Mma::kSparse}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + typename Mma::IteratorE iterator_E( + params_E, ref_E.data(), + {problem_size.m(), + problem_size.k() / Mma::kSparse / Mma::kElementsPerElementE}, + tb_thread_id, tb_offset_E); + + int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); + + // Construct thread-scoped matrix multiply + Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_E, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, threadIdx.x); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_> +struct SparseTestbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + using ElementE = typename MmaCore::ElementE; + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using ThreadMapE = typename MmaCore::IteratorThreadMapE; + using AccessTypeA = cutlass::Array<ElementA, ThreadMapA::kElementsPerAccess>; + using AccessTypeB = cutlass::Array<ElementB, ThreadMapB::kElementsPerAccess>; + using AccessTypeE = cutlass::Array<ElementE, ThreadMapE::kElementsPerAccess>; + static int const Stages = MmaCore::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + MmaCore::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + MmaCore::kCacheOpB; + static cutlass::arch::CacheOperation::Kind const CacheOpE = + MmaCore::kCacheOpE; + + static int const Sparse = MmaCore::kSparse; + static int const MetaSizeInBits = MmaCore::kMetaSizeInBits; + static int const MaxID2 = MmaCore::kMaxID2; + + using LayoutE = cutlass::layout::RowMajor; + using ReorderedLayoutE = typename 
MmaCore::GmemLayoutE; + + static int const ElementsPerElementE = MmaCore::kElementsPerElementE; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define iterators over tiles from the E operand + using IteratorE = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementE, ReorderedLayoutE, 1, ThreadMapE, AccessTypeE>; + + // Define the threadblock-scoped pipelined matrix multiply + using Mma = cutlass::gemm::threadblock::SparseMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + CacheOpA, IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, + LayoutC, IteratorE, typename MmaCore::SmemIteratorE, CacheOpE, + typename MmaCore::MmaPolicy, Stages>; + + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_A_uncompressed; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed; + cutlass::HostTensor matrix_C_reference; + cutlass::HostTensor matrix_E; + cutlass::HostTensor matrix_E_reordered; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + SparseTestbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k / Sparse)); + matrix_A_uncompressed.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + matrix_E.reset(cutlass::make_Coord(m, k / Sparse / ElementsPerElementE)); + matrix_E_reordered.reset( + cutlass::make_Coord(m, k / Sparse / ElementsPerElementE)); + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_E = cutlass::Distribution::Uniform) { + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + // TODO: Implement the rest + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), 
seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + // TODO: Implement the rest + return false; + } + + cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + if (init_E == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomSparseMeta( + matrix_E.host_view(), seed, MetaSizeInBits); + } else if (init_E == cutlass::Distribution::Identity) { + uint32_t content = (MaxID2 == 1) ? 0x44444444 : 0x4444; + cutlass::reference::host::TensorFill(matrix_E.host_view(), + (ElementE)(content)); + } else { + // TODO: Implement the rest + return false; + } + + cutlass::reorder_meta(matrix_E_reordered.host_ref(), matrix_E.host_ref(), + {problem_size.m(), problem_size.n(), + problem_size.k() / Sparse / ElementsPerElementE}); + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + matrix_E_reordered.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + typename IteratorE::Params params_E(matrix_E_reordered.layout()); + + cudaError_t result; + + int smem_size = int(sizeof(typename Mma::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma_sparse, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + EXPECT_EQ(result, cudaSuccess) + << " cudaFuncSetAttribute " + "cudaFuncAttributeMaxDynamicSharedMemorySize error: " + << cudaGetErrorString(result); + + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma_sparse, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + EXPECT_EQ(result, cudaSuccess) + << " cudaFuncSetAttribute " + "cudaFuncAttributePreferredSharedMemoryCarveout error: " + << cudaGetErrorString(result); + } + + test::gemm::threadblock::kernel_multistage_mma_sparse + <<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0), params_E, + matrix_E_reordered.device_ref()); + + // + // Check error code + // + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + matrix_C_computed.sync_host(); + + cutlass::uncompress(matrix_A_uncompressed.host_ref(), matrix_A.host_ref(), + matrix_E.host_ref(), problem_size.m(), + problem_size.k()); + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm(problem_size, ElementC(alpha), + matrix_A_uncompressed.host_view(), matrix_B.host_view(), + ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed) + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "E:\n" << matrix_E.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_reference.host_view()), 0); + 
EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_computed.host_view()), 0); + + return passed; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/test/unit/gemm/threadblock/mma_multistage_testbed.h b/test/unit/gemm/threadblock/mma_multistage_testbed.h new file mode 100644 index 0000000000..3870dd22fb --- /dev/null +++ b/test/unit/gemm/threadblock/mma_multistage_testbed.h @@ -0,0 +1,329 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/core_io.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +namespace test { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_multistage_mma(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC *ptr_C, int ldc) { + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + + // Dynamic shared memory base pointer + extern __shared__ int GemmSharedStorageBase[]; + + // Declare pointer to dynamic shared memory. + typename Mma::SharedStorage *shared_storage = + reinterpret_cast(GemmSharedStorageBase); + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); + + // Construct thread-scoped matrix multiply + Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, threadIdx.x); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_> +struct Testbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + 
using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + static int const Stages = MmaCore::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + MmaCore::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + MmaCore::kCacheOpB; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define the threadblock-scoped pipelined matrix multiply + using Mma = cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + CacheOpA, IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, + LayoutC, typename MmaCore::MmaPolicy, Stages>; + + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed; + cutlass::HostTensor matrix_C_reference; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + // TODO: Implement the rest + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + 
matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + // TODO: Implement the rest + return false; + } + + cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + cudaError_t result; + + int smem_size = int(sizeof(typename Mma::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + EXPECT_EQ(result, cudaSuccess) + << " cudaFuncSetAttribute " + "cudaFuncAttributeMaxDynamicSharedMemorySize error: " + << cudaGetErrorString(result); + + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + EXPECT_EQ(result, cudaSuccess) + << " cudaFuncSetAttribute " + "cudaFuncAttributePreferredSharedMemoryCarveout error: " + << cudaGetErrorString(result); + } + + test::gemm::threadblock::kernel_multistage_mma + <<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0)); + + // + // Check error code + // + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + matrix_C_computed.sync_host(); + + cutlass::reference::host::Gemm reference_gemm; + + reference_gemm( + problem_size, ElementC(alpha), matrix_A.host_view(), + matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed) + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_reference.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_computed.host_view()), 0); + + return passed; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/test/unit/gemm/threadblock/mma_pipelined_sm80.cu b/test/unit/gemm/threadblock/mma_pipelined_sm80.cu new file mode 100644 index 0000000000..14dd68e72d --- /dev/null +++ b/test/unit/gemm/threadblock/mma_pipelined_sm80.cu @@ -0,0 +1,563 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for threadblock-level GEMM +*/ + +#include "mma_pipelined_testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, tensor_op_64x64x16_64x64x16_16x8x4) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 64); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + float alpha = 1.f; + float beta = 0.0f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, + cutlass::arch::OpClassTensorOp>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, tensor_op_128x64x16_64x32x16_16x8x4) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 64); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + float alpha = 1.f; + float beta = 0.0f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, + cutlass::arch::OpClassTensorOp>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + 
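+// Note on the launch shapes used throughout these tests: block.x is a single
+// warp of 32 threads and block.y is the number of warps per threadblock,
+// i.e. (ThreadblockShape::kM / WarpShape::kM) * (ThreadblockShape::kN / WarpShape::kN),
+// while the grid covers the problem with one CTA per threadblock tile.
+// The helper below is only an illustrative sketch of that rule; its name is
+// arbitrary and it is not used by the tests, which spell out grid and block
+// explicitly.
+template <typename MmaCore>
+dim3 sketch_block_dim() {
+  int const warps_m = MmaCore::Shape::kM / MmaCore::WarpShape::kM;
+  int const warps_n = MmaCore::Shape::kN / MmaCore::WarpShape::kN;
+  // One warp of 32 threads in x, one warp per (warp_m, warp_n) tile in y.
+  return dim3(32, warps_m * warps_n, 1);
+}
+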
+//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, tensor_op_64x128x16_32x64x16_16x8x4) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 64); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 16>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + float alpha = 1.f; + float beta = 0.0f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, + cutlass::arch::OpClassTensorOp>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, tensor_op_128x128x16_64x64x16_16x8x4) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 64); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + float alpha = 1.f; + float beta = 0.0f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, + cutlass::arch::OpClassTensorOp>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + multicta_256x256x96_128x128x16_64x64x16_16x8x4) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 96); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + float alpha = 1.f; + float beta = 0.0f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, + cutlass::arch::OpClassTensorOp>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_congruous, + 
multicta_512x256x192_256x128x16_64x64x16_16x8x4) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 192); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + float alpha = 1.f; + float beta = 0.0f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, + cutlass::arch::OpClassTensorOp>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x64x16_64x64x16_16x8x4) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 64); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + float alpha = 1.f; + float beta = 0.0f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, + cutlass::arch::OpClassTensorOp>; + + dim3 grid(1, 1); + dim3 block(32, 1, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, tensor_op_32x32x16_16x16x16_16x8x4) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(32, 32, 64); + + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 16>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + float alpha = 1.f; + float beta = 0.0f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, + cutlass::arch::OpClassTensorOp>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, tensor_op_32x64x16_16x32x16_16x8x4) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + 
using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(32, 64, 64); + + using ThreadblockShape = cutlass::gemm::GemmShape<32, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<16, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + float alpha = 1.f; + float beta = 0.0f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, + cutlass::arch::OpClassTensorOp>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x32x16_32x16x16_16x8x4) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 32, 64); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 32, 16>; + using WarpShape = cutlass::gemm::GemmShape<32, 16, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + float alpha = 1.f; + float beta = 0.0f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, + cutlass::arch::OpClassTensorOp>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x64x16_32x32x16_16x8x4) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 64); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + float alpha = 1.f; + float beta = 0.0f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, tensor_op_128x64x16_64x32x16_16x8x4) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 64, 64); + + using 
ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + float alpha = 1.f; + float beta = 0.0f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, tensor_op_64x128x16_32x64x16_16x8x4) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(64, 128, 64); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 128, 16>; + using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + float alpha = 1.f; + float beta = 0.0f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, tensor_op_128x128x16_64x64x16_16x8x4) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(128, 128, 48); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + float alpha = 1.f; + float beta = 0.0f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; + + dim3 grid(1, 1); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_256x256x48_128x128x16_64x64x16_16x8x4) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(256, 256, 48); + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; + using InstructionShape = 
cutlass::gemm::GemmShape<16, 8, 4>; + + float alpha = 1.f; + float beta = 0.0f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; + + dim3 grid(2, 2); + dim3 block(32, 4, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_crosswise, + multicta_512x256x192_256x128x16_64x64x16_16x8x4) { + using ElementA = cutlass::tfloat32_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::tfloat32_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + + cutlass::gemm::GemmCoord problem_size(512, 256, 192); + + using ThreadblockShape = cutlass::gemm::GemmShape<256, 128, 16>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; + + float alpha = 1.f; + float beta = 0.0f; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementC, LayoutC, cutlass::arch::OpClassTensorOp>; + + dim3 grid(2, 2); + dim3 block(32, 8, 1); + + test::gemm::threadblock::Testbed(problem_size.m(), problem_size.n(), + problem_size.k(), alpha, beta) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + + diff --git a/test/unit/gemm/threadblock/mma_planar_complex_sm80.cu b/test/unit/gemm/threadblock/mma_planar_complex_sm80.cu new file mode 100644 index 0000000000..ebcf0a355e --- /dev/null +++ b/test/unit/gemm/threadblock/mma_planar_complex_sm80.cu @@ -0,0 +1,73 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for threadblock-level GEMM +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/threadblock/default_mma_planar_complex_multistage.h" + +#include "mma_planar_complex_testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_gemm_threadblock_planar_complex_congruous, tensor_op_64x64x32_64x64x32_16x8x16_3stage) { + + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::ColumnMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::RowMajor; + using ElementC = float; + using LayoutC = cutlass::layout::RowMajor; + + cutlass::gemm::GemmCoord problem_size(64, 64, 8); + + using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<32, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + int const Stages = 3; + + // Define the MmaCore components + using Mma = typename cutlass::gemm::threadblock::DefaultMmaPlanarComplexMultistage< + ElementA, LayoutA, 8, + ElementB, LayoutB, 8, + ElementC, LayoutC, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, + Stages>::ThreadblockMma; + + dim3 grid(1, 1); + dim3 block(32, Mma::WarpCount::kCount, 1); + + test::gemm::threadblock::TestbedPlanarComplex<Mma>(problem_size.m(), problem_size.n(), + problem_size.k()) + .run(grid, block); +} + +//////////////////////////////////////////////////////////////////////////////// +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/warp/gemm_complex_sm80.cu b/test/unit/gemm/warp/gemm_complex_sm80.cu index 3fcd70c8d0..99effe4004 100644 --- a/test/unit/gemm/warp/gemm_complex_sm80.cu +++ b/test/unit/gemm/warp/gemm_complex_sm80.cu @@ -629,7 +629,66 @@ TEST(SM80_warp_gemm_complex_tensor_op_f32, 64x32x8_16x8x8_tn) { .run(); } -/////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_complex_tensor_op_f64, 32x32x8_8x8x4_tn) { + + using Shape = cutlass::gemm::GemmShape<32, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex<double>; + using ElementC = cutlass::complex<double>; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TransformedTestbedComplex< + MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 8> >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////////////////////// +
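The f64 complex warp test above pairs a 32x32x4 warp tile with the 8x8x4 double-precision tensor core instruction. A rough standalone sketch of the tiling arithmetic those shapes imply follows; the constant names are illustrative only, and the four-way expansion of a complex multiply-accumulate into real MMAs is stated as an assumption about the usual (non-Gaussian) lowering, not something this test asserts.

#include "cutlass/gemm/gemm.h"

// Warp tile and instruction tile taken from the 32x32x8_8x8x4_tn test above.
using WarpTile        = cutlass::gemm::GemmShape<32, 32, 4>;
using InstructionTile = cutlass::gemm::GemmShape<8, 8, 4>;

static_assert(WarpTile::kM % InstructionTile::kM == 0, "warp M must tile evenly");
static_assert(WarpTile::kN % InstructionTile::kN == 0, "warp N must tile evenly");

// Each warp covers a 4 x 4 grid of 8x8x4 instruction tiles per K-step.
int const kMmaIterationsM = WarpTile::kM / InstructionTile::kM;  // 4
int const kMmaIterationsN = WarpTile::kN / InstructionTile::kN;  // 4

// Assuming each complex MMA expands into four real-valued MMAs
// (real*real, real*imag, imag*real, imag*imag), one warp issues
// 4 * 4 * 4 = 64 real double-precision mma operations per K-step of the warp tile.
int const kRealMmasPerKStep = kMmaIterationsM * kMmaIterationsN * 4;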
+TEST(SM80_warp_gemm_complex_tensor_op_f64, 32x32x8_8x8x4_nt) { + + using Shape = cutlass::gemm::GemmShape<32, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using Element = cutlass::complex; + using ElementC = cutlass::complex; + + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaComplexTensorOp< + Shape, + InstructionShape, + Element, + LayoutA, + Element, + LayoutB, + ElementC, + cutlass::layout::RowMajor + >::Type; + + test::gemm::warp::TransformedTestbedComplex< + MmaTensorOp, cutlass::gemm::GemmShape<32, 32, 8> >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////////////////////// + #endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/warp/gemm_sm80.cu b/test/unit/gemm/warp/gemm_sm80.cu index 377e760c6b..0f736b1355 100644 --- a/test/unit/gemm/warp/gemm_sm80.cu +++ b/test/unit/gemm/warp/gemm_sm80.cu @@ -1778,5 +1778,81 @@ TEST(SM80_warp_gemm_tensor_op_interleaved, 128x128x128_64x64x128_16x8x64) { //////////////////////////////////////////////////////////////////////////////// +TEST(SM80_warp_gemm_tensor_op_canonical_f64_row_col, 32x32x8_64x32x8_8x8x4) { + using Shape = cutlass::gemm::GemmShape<32, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Element = double; + using ElementC = double; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_canonical_f64_col_row, 32x32x8_64x32x8_8x8x4) { + using Shape = cutlass::gemm::GemmShape<32, 32, 4>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Element = double; + using ElementC = double; + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_canonical_tf32_row_col, 32x32x8_64x32x8_8x8x4) { + using Shape = cutlass::gemm::GemmShape<32, 32, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_tensor_op_canonical_tf32_col_row, 32x32x8_64x32x8_8x8x4) { + using Shape = cutlass::gemm::GemmShape<32, 32, 8>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + + using 
MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::Testbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + #endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/warp/gemm_sparse_sm80.cu b/test/unit/gemm/warp/gemm_sparse_sm80.cu new file mode 100644 index 0000000000..8df0846076 --- /dev/null +++ b/test/unit/gemm/warp/gemm_sparse_sm80.cu @@ -0,0 +1,1101 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*!
\file + + \brief Unit tests for thread-level GEMM +*/ + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" + +#include "cutlass/gemm/warp/default_mma_sparse_tensor_op.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x128x64_64x64x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x128x64_64x32x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<64, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x128x64_32x64x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<32, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x128x64_32x32x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = 
cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x128x64_32x16x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<32, 16, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 128x64x128_64x32x128_16x8x32) { + using Shape = cutlass::gemm::GemmShape<64, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 64x128x128_32x64x128_16x8x32) { + using Shape = cutlass::gemm::GemmShape<32, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 64x64x128_32x32x128_16x8x32) { + using Shape = cutlass::gemm::GemmShape<32, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + 
test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_f16, 64x32x128_32x16x128_16x8x32) { + using Shape = cutlass::gemm::GemmShape<32, 16, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 128x128x64_64x64x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<64, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 128x128x64_64x32x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<64, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 128x128x64_32x64x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<32, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 128x128x64_32x32x64_16x8x32) { + using Shape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape 
= cutlass::gemm::GemmShape<16, 8, 32>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 128x64x128_64x32x128_16x8x32) { + using Shape = cutlass::gemm::GemmShape<64, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 64x128x128_32x64x128_16x8x32) { + using Shape = cutlass::gemm::GemmShape<32, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_congruous_f16, 64x64x128_32x32x128_16x8x32) { + using Shape = cutlass::gemm::GemmShape<32, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + using Element = cutlass::half_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 64>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_64x64x128_16x8x64) { + using Shape = cutlass::gemm::GemmShape<64, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = int8_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + 
cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_64x32x128_16x8x64) { + using Shape = cutlass::gemm::GemmShape<64, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = int8_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_32x64x128_16x8x64) { + using Shape = cutlass::gemm::GemmShape<32, 64, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = int8_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_32x32x128_16x8x64) { + using Shape = cutlass::gemm::GemmShape<32, 32, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = int8_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x128x128_32x16x128_16x8x64) { + using Shape = cutlass::gemm::GemmShape<32, 16, 128>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = int8_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 64>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + 
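Every sparse tensor op test in this file relies on the same 50% structured-sparsity bookkeeping: operand A is stored compressed along K, and a separate metadata operand E records which elements were kept. The standalone sketch below restates the resulting storage extents in terms of the kSparse and kElementsPerElementE constants exposed by the warp-level Mma types used above; the struct and its member names are illustrative, mirroring the allocations SparseTestbed makes later in this patch.

#include "cutlass/gemm/gemm.h"

// Given a warp-level sparse Mma type (e.g. one of the MmaTensorOp aliases above)
// and an overall tile shape, compute the stored extents of compressed A and metadata E.
template <typename Mma, typename Shape>
struct SparseExtents {
  // A is kept at 50% density: only Shape::kK / Mma::kSparse elements per row are stored.
  static int const kCompressedColumnsA = Shape::kK / Mma::kSparse;

  // Each metadata element (of type Mma::ElementE) covers kElementsPerElementE
  // compressed A elements, so the metadata tile is narrower still.
  static int const kMetadataColumnsE =
      Shape::kK / Mma::kSparse / Mma::kElementsPerElementE;
};

// Example usage against the 64x64x128 int8 warp tile above:
//   using Extents = SparseExtents<MmaTensorOp, cutlass::gemm::GemmShape<64, 64, 128>>;
//   static_assert(Extents::kCompressedColumnsA == 64, "half of K is stored");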
+//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 128x64x256_64x32x256_16x8x64) { + using Shape = cutlass::gemm::GemmShape<64, 32, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = int8_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 64x128x256_32x64x256_16x8x64) { + using Shape = cutlass::gemm::GemmShape<32, 64, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = int8_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 64x64x256_32x32x256_16x8x64) { + using Shape = cutlass::gemm::GemmShape<32, 32, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = int8_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s8, 64x32x256_32x16x256_16x8x64) { + using Shape = cutlass::gemm::GemmShape<32, 16, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 64>; + using Element = int8_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 128x128x32_64x64x32_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = 
cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 16>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 128x128x32_64x32x32_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 16>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 128x128x32_32x64x32_16x8x16) { + using Shape = cutlass::gemm::GemmShape<32, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 16>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 128x128x32_32x32x32_16x8x16) { + using Shape = cutlass::gemm::GemmShape<32, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 16>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 128x128x32_32x16x32_16x8x16) { + using Shape = cutlass::gemm::GemmShape<32, 16, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 16>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + 
using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 128x64x256_64x32x256_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 64x128x64_32x64x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<32, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 64x64x64_32x32x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_tf32, 64x32x64_32x16x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<32, 16, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + 
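A pattern worth noting in the crosswise shared-memory layouts used by these sparse tests: the second template argument of RowMajorTensorOpMultiplicandCrosswise and ColumnMajorTensorOpMultiplicandCrosswise follows the operand's K extent in elements, apparently clamped to one 128-byte (1024-bit) contiguous chunk, with the compressed A operand contributing kK / 2 elements and B the full kK. The sketch below restates that rule; the clamp is an inference from the values used above rather than a definition quoted from CUTLASS, and the helper name is illustrative.

#include "cutlass/numeric_types.h"
#include "cutlass/tfloat32.h"

// Crosswise extent used when selecting the shared-memory layouts above:
// the operand's K extent in elements, clamped to a 1024-bit contiguous chunk.
template <typename Element>
constexpr int crosswise_extent(int k_elements) {
  return k_elements < (1024 / cutlass::sizeof_bits<Element>::value)
             ? k_elements
             : (1024 / cutlass::sizeof_bits<Element>::value);
}

// Checks against the tf32 tests above (kSparse == 2 for the compressed A operand):
static_assert(crosswise_extent<cutlass::tfloat32_t>(32 / 2) == 16, "A operand, 64x64x32 warp tile");
static_assert(crosswise_extent<cutlass::tfloat32_t>(32) == 32, "B operand, 64x64x32 warp tile");
static_assert(crosswise_extent<cutlass::tfloat32_t>(64 / 2) == 32, "A operand, K = 64 tiles hit the clamp");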
+//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 128x128x32_64x64x32_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 128x128x32_64x32x32_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 128x128x32_32x64x32_16x8x16) { + using Shape = cutlass::gemm::GemmShape<32, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 128x128x32_32x32x32_16x8x16) { + using Shape = cutlass::gemm::GemmShape<32, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 128x64x64_64x32x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<64, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 
8, 16>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 64x128x64_32x64x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<32, 64, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_congruous_tf32, 64x64x64_32x32x64_16x8x16) { + using Shape = cutlass::gemm::GemmShape<32, 32, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using Element = cutlass::tfloat32_t; + using ElementC = float; + using LayoutA = cutlass::layout::ColumnMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + using LayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous< + cutlass::sizeof_bits::value, 32>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_64x64x256_16x8x128) { + using Shape = cutlass::gemm::GemmShape<64, 64, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + using Element = cutlass::int4b_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_64x32x256_16x8x128) { + using Shape = cutlass::gemm::GemmShape<64, 32, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + using Element = cutlass::int4b_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + 
cutlass::sizeof_bits::value, 256>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_32x64x256_16x8x128) { + using Shape = cutlass::gemm::GemmShape<32, 64, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + using Element = cutlass::int4b_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_32x32x256_16x8x128) { + using Shape = cutlass::gemm::GemmShape<32, 32, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + using Element = cutlass::int4b_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x128x256_32x16x256_16x8x128) { + using Shape = cutlass::gemm::GemmShape<32, 16, 256>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + using Element = cutlass::int4b_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 128>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 128x64x512_64x32x512_16x8x128) { + using Shape = cutlass::gemm::GemmShape<64, 32, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + using Element = cutlass::int4b_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() 
+ .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 64x128x512_32x64x512_16x8x128) { + using Shape = cutlass::gemm::GemmShape<32, 64, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + using Element = cutlass::int4b_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 64x64x512_32x32x512_16x8x128) { + using Shape = cutlass::gemm::GemmShape<32, 32, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + using Element = cutlass::int4b_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_warp_gemm_sparse_tensor_op_crosswise_s4, 64x32x512_32x16x512_16x8x128) { + using Shape = cutlass::gemm::GemmShape<32, 16, 512>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 128>; + using Element = cutlass::int4b_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise< + cutlass::sizeof_bits::value, 256>; + + using MmaTensorOp = typename cutlass::gemm::warp::DefaultSparseMmaTensorOp< + Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC, + cutlass::layout::RowMajor>::Type; + + test::gemm::warp::SparseTestbed >() + .run(); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_SUPPORTED) diff --git a/test/unit/gemm/warp/testbed.h b/test/unit/gemm/warp/testbed.h index 8a565fd9fd..c0c98d80df 100644 --- a/test/unit/gemm/warp/testbed.h +++ b/test/unit/gemm/warp/testbed.h @@ -31,6 +31,7 @@ #include "cutlass/cutlass.h" #include "cutlass/aligned_buffer.h" #include "cutlass/subbyte_reference.h" +#include "cutlass/platform/platform.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/tensor_view_io.h" @@ -41,6 +42,8 @@ #include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/host/tensor_copy.h" #include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" namespace test { namespace gemm { @@ -996,7 +999,359 @@ struct TransformedTestbedComplex { ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Test kernel +template +__global__ void 
sparse_kernel( + typename Mma::ElementC *output_C, + typename Mma::ElementA const *input_A, + typename Mma::ElementB const *input_B, + typename Mma::ElementC const *input_C, + typename Mma::ElementE const *input_E, + int iterations = 1) { + + // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. + __shared__ cutlass::AlignedBuffer + smem_buffer_A; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementE, ThreadblockShape::kM * ThreadblockShape::kK / + Mma::kSparse / Mma::kElementsPerElementE> + smem_buffer_E; + + if (threadIdx.x == 0) { + typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); + #pragma unroll 1 + for (int i = 0; i < smem_buffer_A.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_A, i) = + cutlass::ReferenceFactory::type>::get(input_A, i); + } + + typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); + #pragma unroll 1 + for (int i = 0; i < smem_buffer_B.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_B, i) = + cutlass::ReferenceFactory::type>::get(input_B, i); + } + + typename Mma::ElementE *smem_ptr_E = smem_buffer_E.data(); + #pragma unroll 1 + for (int i = 0; i < smem_buffer_E.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_E, i) = + cutlass::ReferenceFactory::type>::get(input_E, i); + } + } + + __syncthreads(); + + // + // Construct warp-level matrix product + // + + using FragmentA = typename Mma::FragmentA; + using FragmentB = typename Mma::FragmentB; + using FragmentC = typename Mma::FragmentC; + using FragmentE = typename Mma::FragmentE; + + typename Mma::LayoutA layout_A = Mma::LayoutA::packed( + {ThreadblockShape::kM, ThreadblockShape::kK / Mma::kSparse}); + typename Mma::LayoutB layout_B = + Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); + typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); + typename Mma::LayoutE layout_E = + Mma::LayoutE::packed({Mma::Shape::kM * Mma::kInterleaved, + Mma::Shape::kK / Mma::kSparse / + Mma::kElementsPerElementE / Mma::kInterleaved}); + + typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::LaneId()); + + typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::LaneId()); + + typename Mma::IteratorE iter_E({smem_buffer_E.data(), layout_E}, cutlass::LaneId()); + + FragmentA frag_A; + FragmentB frag_B; + + FragmentC accum; + + FragmentE frag_E; + + Mma mma; + + accum.clear(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < ThreadblockShape::kK; + k += Mma::Policy::MmaShape::kK) { + iter_A.load(frag_A); + iter_B.load(frag_B); + iter_E.load(frag_E); + + ++iter_A; + ++iter_B; + ++iter_E; + + mma(accum, frag_A, frag_B, accum, frag_E); + } + } + + typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::LaneId()); + + iter_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_, + /// The innter product operation performed by GEMM + typename Operator_ = cutlass::arch::OpMultiplyAdd +> +struct SparseTestbed { + + /// Thread-level matrix multiply-accumulate 
operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + using Operator = Operator_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + static int const Sparse = Mma::kSparse; + static int const MetaSizeInBits = Mma::kMetaSizeInBits; + static int const MaxID2 = Mma::kMaxID2; + static int const Interleaved = Mma::kInterleaved; + + using ElementE = typename Mma::ElementE; + + static int const ElementsPerElementE = Mma::kElementsPerElementE; + + using LayoutE = cutlass::layout::RowMajor; + using ReorderedLayoutE = + cutlass::layout::ColumnMajorInterleaved; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_uncompressed; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + cutlass::HostTensor tensor_E; + cutlass::HostTensor tensor_E_reordered; + + // + // Methods + // + + /// Allocates workspace in device memory + SparseTestbed() { + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, + ThreadblockShape::kK / Sparse)); + tensor_A_uncompressed.reset( + cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + tensor_E.reset(cutlass::make_Coord( + Shape::kM, Shape::kK / Sparse / ElementsPerElementE)); + tensor_E_reordered.reset(cutlass::make_Coord( + Shape::kM, Shape::kK / Sparse / ElementsPerElementE)); + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_E = cutlass::Distribution::Uniform) { + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + // TODO: Implement the rest + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == 
cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + // TODO: Implement the rest + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + if (init_E == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomSparseMeta( + tensor_E.host_view(), seed, MetaSizeInBits); + } else if (init_E == cutlass::Distribution::Identity) { + uint32_t content = (MaxID2 == 1) ? 0x44444444 : 0x4444; + cutlass::reference::host::TensorFill(tensor_E.host_view(), + (ElementE)(content)); + } else { + // TODO: Implement the rest + return false; + } + + cutlass::reorder_meta( + tensor_E_reordered.host_ref(), tensor_E.host_ref(), + {Shape::kM, Shape::kN, Shape::kK / Sparse / ElementsPerElementE}); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_E_reordered.sync_device(); + + // launch kernel + sparse_kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_E_reordered.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + cutlass::uncompress(tensor_A_uncompressed.host_ref(), tensor_A.host_ref(), + tensor_E.host_ref(), Shape::kM, Shape::kK); + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A_uncompressed.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout << "A:\n" << tensor_A.host_view() << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout << "B:\n" << tensor_B.host_view() << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout << "E:\n" << tensor_E.host_view() << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace warp } // namespace gemm } // namespace test - diff --git a/test/unit/layout/matrix.cu b/test/unit/layout/matrix.cu index 2f8d0ea2be..e463f0974e 100644 --- a/test/unit/layout/matrix.cu +++ b/test/unit/layout/matrix.cu @@ -122,8 +122,9 @@ TEST(Layout_Matrix, general_matrix) { cutlass::layout::GeneralMatrix::TensorCoord extent = {M, N}; - cutlass::layout::GeneralMatrix layout = - cutlass::layout::GeneralMatrix::packed(extent, cutlass::MatrixLayout::kColumnMajor, interleave); + cutlass::layout::GeneralMatrix layout = + cutlass::layout::GeneralMatrix::packed( + extent, 
cutlass::layout::Matrix::kColumnMajor, interleave); cutlass::HostTensor tensor(extent); diff --git a/test/unit/util/CMakeLists.txt b/test/unit/util/CMakeLists.txt new file mode 100644 index 0000000000..7f103cbf3c --- /dev/null +++ b/test/unit/util/CMakeLists.txt @@ -0,0 +1,26 @@ +# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without modification, are permitted +# provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright notice, this list of +# conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright notice, this list of +# conditions and the following disclaimer in the documentation and/or other materials +# provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific prior written +# permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR +# IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +# FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; +# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, +# STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_test_unit_add_executable( + cutlass_test_unit_util + tensor_reduce.cu + ) diff --git a/test/unit/util/complex.cu b/test/unit/util/complex.cu deleted file mode 100644 index 319bbb2aa4..0000000000 --- a/test/unit/util/complex.cu +++ /dev/null @@ -1,102 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#include - -#include "cutlass_unit_test.h" -#include "cutlass/util/complex.h" -#include "tools/util/half.h" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace test { - - /// Thorough testing for basic complex math operators. Uses std::complex as a reference. - template - struct ComplexOperators { - ComplexOperators() { - for (int ar = -N; ar <= N; ++ar) { - for (int ai = -N; ai <= N; ++ai) { - for (int br = -N; br <= N; ++br) { - for (int bi = -N; bi <= N; ++bi) { - - cutlass::platform::complex Ae(T(ar) / T(M), T(ai) / T(M)); - cutlass::platform::complex Be(T(br) / T(M), T(bi) / T(M)); - - std::complex Ar(T(ar) / T(M), T(ai) / T(M)); - std::complex Br(T(br) / T(M), T(bi) / T(M)); - - cutlass::platform::complex add_e = Ae + Be; - cutlass::platform::complex sub_e = Ae - Be; - cutlass::platform::complex mul_e = Ae * Be; - - std::complex add_r = (Ar + Br); - std::complex sub_r = (Ar - Br); - std::complex mul_r = (Ar * Br); - - EXPECT_EQ(real(add_e), real(add_r)); - EXPECT_EQ(imag(add_e), imag(add_r)); - - EXPECT_EQ(real(sub_e), real(sub_r)); - EXPECT_EQ(imag(sub_e), imag(sub_r)); - - EXPECT_EQ(real(mul_e), real(mul_r)); - EXPECT_EQ(imag(mul_e), imag(mul_r)); - - if (!(br == 0 && bi == 0)) { - - cutlass::platform::complex div_e = Ae * Be; - std::complex div_r = Ar * Br; - - EXPECT_EQ(real(div_e), real(div_r)); - EXPECT_EQ(imag(div_e), imag(div_r)); - } - } - } - } - } - } - }; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -TEST(Complex, host_float) { - test::ComplexOperators test; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -TEST(Complex, host_double) { - test::ComplexOperators test; -} - -/////////////////////////////////////////////////////////////////////////////////////// - -TEST(Complex, host_half) { - // Fewer test cases since half_t is emulated - test::ComplexOperators test; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/util/tensor_reduce.cu b/test/unit/util/tensor_reduce.cu new file mode 100644 index 0000000000..5a1afc7f39 --- /dev/null +++ b/test/unit/util/tensor_reduce.cu @@ -0,0 +1,238 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include + +#include "../common/cutlass_unit_test.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/util/reference/device/tensor_reduce.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/host_tensor.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(TensorReduce, norm_rowmajor_f32) { + + int const kM = 129; + int const kN = 91; + + cutlass::HostTensor tensor({kM, kN}); + + for (int m = 0; m < kM; ++m) { + for (int n = 0; n < kN; ++n) { + + float x = float(((m * kN + m + 7) % 8) - 4); + + tensor.at({m, n}) = x; + } + } + + tensor.sync_device(); + + double device_norm = cutlass::reference::device::TensorNorm(tensor.device_view(), double()); + double host_norm = cutlass::reference::host::TensorNorm(tensor.host_view(), double()); + + EXPECT_TRUE(std::abs(host_norm - device_norm) < 0.001); +} + +TEST(TensorReduce, norm_nhwc_f32) { + + int const kN = 19; + int const kH = 18; + int const kW = 17; + int const kC = 16; + + cutlass::HostTensor tensor({kN, kH, kW, kC}); + + int idx = 0; + + double computed_norm = double(); + + for (int n = 0; n < kN; ++n) { + for (int h = 0; h < kH; ++h) { + for (int w = 0; w < kW; ++w) { + for (int c = 0; c < kC; ++c, ++idx) { + + float x = float(((idx + 7) % 8) - 4); + + computed_norm += double(x) * double(x); + + tensor.at({n, h, w, c}) = x; + } + } + } + } + + computed_norm = std::sqrt(computed_norm); + + tensor.sync_device(); + + double device_norm = cutlass::reference::device::TensorNorm(tensor.device_view(), double()); + double host_norm = cutlass::reference::host::TensorNorm(tensor.host_view(), double()); + + EXPECT_TRUE(std::abs(host_norm - device_norm) < 0.001 && std::abs(computed_norm - host_norm) < 0.001) + << "computed norm: " << computed_norm << "\n" + << " host norm: " << host_norm << "\n" + << "device norm: " << device_norm << "\n"; +} + +TEST(TensorReduce, norm_nhwc_f16) { + + int const kN = 69; + int const kH = 68; + int const kW = 67; + int const kC = 66; + + cutlass::HostTensor tensor({kN, kH, kW, kC}); + + int idx = 0; + + double computed_norm = double(); + + for (int n = 0; n < kN; ++n) { + for (int h = 0; h < kH; ++h) { + for 
(int w = 0; w < kW; ++w) { + for (int c = 0; c < kC; ++c, ++idx) { + + float x = float(((idx + 7) % 8) - 4); + computed_norm += double(x) * double(x); + + tensor.at({n, h, w, c}) = cutlass::half_t(x); + } + } + } + } + + computed_norm = std::sqrt(computed_norm); + + tensor.sync_device(); + + double device_norm = cutlass::reference::device::TensorNorm(tensor.device_view(), double()); + double host_norm = cutlass::reference::host::TensorNorm(tensor.host_view(), double()); + + EXPECT_TRUE(std::abs(host_norm - device_norm) < 0.001 && std::abs(computed_norm - host_norm) < 0.001) + << "computed norm: " << computed_norm << "\n" + << " host norm: " << host_norm << "\n" + << "device norm: " << device_norm << "\n"; +} + +TEST(TensorReduce, norm_diff_nhwc_f32) { + + int const kN = 59; + int const kH = 24; + int const kW = 57; + int const kC = 78; + + using Layout = cutlass::layout::TensorNHWC; + + cutlass::HostTensor tensor_A({kN, kH, kW, kC}); + cutlass::HostTensor tensor_B({kN, kH, kW, kC}); + + + int idx = 0; + + double sum_sq_diff = 0; + + for (int n = 0; n < kN; ++n) { + for (int h = 0; h < kH; ++h) { + for (int w = 0; w < kW; ++w) { + for (int c = 0; c < kC; ++c, ++idx) { + + float a = float(((idx * 5 + 7) % 8) - 4); + float b = float(((idx * 3 + 7) % 8) - 4); + + sum_sq_diff += double(a - b) * double(a - b); + + tensor_A.at({n, h, w, c}) = a; + tensor_B.at({n, h, w, c}) = b; + } + } + } + } + + tensor_A.sync_device(); + tensor_B.sync_device(); + + double device_norm = cutlass::reference::device::TensorNormDiff( + tensor_A.device_view(), tensor_B.device_view(), double()); + + double host_norm = std::sqrt(sum_sq_diff); + + EXPECT_TRUE(std::abs(host_norm - device_norm) < 0.001f) + << " host norm: " << host_norm << "\n" + << "device norm: " << device_norm; +} + + +TEST(TensorReduce, norm_diff_nhwc_f16) { + + int const kN = 59; + int const kH = 24; + int const kW = 57; + int const kC = 78; + + using Layout = cutlass::layout::TensorNHWC; + + cutlass::HostTensor tensor_A({kN, kH, kW, kC}); + cutlass::HostTensor tensor_B({kN, kH, kW, kC}); + + int idx = 0; + + double sum_sq_diff = 0; + + for (int n = 0; n < kN; ++n) { + for (int h = 0; h < kH; ++h) { + for (int w = 0; w < kW; ++w) { + for (int c = 0; c < kC; ++c, ++idx) { + + float a = float(((idx * 5 + 7) % 8) - 4); + float b = float(((idx * 3 + 7) % 8) - 4); + + sum_sq_diff += double(a - b) * double(a - b); + + tensor_A.at({n, h, w, c}) = cutlass::half_t(a); + tensor_B.at({n, h, w, c}) = cutlass::half_t(b); + } + } + } + } + + tensor_A.sync_device(); + tensor_B.sync_device(); + + double device_norm = cutlass::reference::device::TensorNormDiff( + tensor_A.device_view(), tensor_B.device_view(), double()); + + double host_norm = std::sqrt(sum_sq_diff); + + EXPECT_TRUE(std::abs(host_norm - device_norm) < 0.001f) + << " host norm: " << host_norm << "\n" + << "device norm: " << device_norm; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 5c140a9a76..3ca637b2db 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -21,6 +21,11 @@ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
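# Library and profiler builds are gated by CUTLASS_ENABLE_LIBRARY and
# CUTLASS_ENABLE_PROFILER (see the conditionals added below). Assuming those
# options are exposed as cache variables by the top-level CMakeLists, an
# illustrative configure line that skips both components would be:
#   cmake .. -DCUTLASS_ENABLE_LIBRARY=OFF -DCUTLASS_ENABLE_PROFILER=OFF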
add_subdirectory(util) -add_subdirectory(library) -add_subdirectory(profiler) + +if (CUTLASS_ENABLE_LIBRARY) + add_subdirectory(library) +endif() +if (CUTLASS_ENABLE_PROFILER) + add_subdirectory(profiler) +endif() diff --git a/tools/library/CMakeLists.txt b/tools/library/CMakeLists.txt index 37bb89901e..294cd98f01 100644 --- a/tools/library/CMakeLists.txt +++ b/tools/library/CMakeLists.txt @@ -60,6 +60,9 @@ cutlass_add_library( src/singleton.cu src/util.cu + src/reference/gemm.cu + src/reference/initialize_reference_operations.cu + ) file(GLOB_RECURSE GENERATOR_PYTHON_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/scripts/*.py) @@ -80,6 +83,7 @@ execute_process( --generator-target library --architectures "${CUTLASS_NVCC_ARCHS_ENABLED}" --kernels "${CUTLASS_LIBRARY_KERNELS}" + --ignore-kernels "${CUTLASS_LIBRARY_IGNORE_KERNELS}" --cuda-version "${CUTLASS_GENERATOR_CUDA_COMPILER_VERSION}" RESULT_VARIABLE cutlass_lib_INSTANCE_GENERATION_RESULT OUTPUT_VARIABLE cutlass_lib_INSTANCE_GENERATION_OUTPUT diff --git a/tools/library/include/cutlass/library/arch_mappings.h b/tools/library/include/cutlass/library/arch_mappings.h new file mode 100644 index 0000000000..787e471280 --- /dev/null +++ b/tools/library/include/cutlass/library/arch_mappings.h @@ -0,0 +1,99 @@ +/*************************************************************************************************** + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. 
+ + Generally, + + description - compile-time constant parameters used to instantiate an operation + + configuration - runtime parameters with computationally expensive initialization + + arguments - runtime parameters that may be passed to an initialized operation with low + computational overhead +*/ + +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/arch/arch.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct ArchMap; + +template <> struct ArchMap { + static int const kMin = 50; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 60; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 61; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 70; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 70; + static int const kMax = 75; +}; + +template struct ArchMap { + static int const kMin = 75; + static int const kMax = 1024; +}; + +template struct ArchMap { + static int const kMin = 80; + static int const kMax = 1024; +}; + +template struct ArchMap { + static int const kMin = 86; + static int const kMax = 1024; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/include/cutlass/library/library.h b/tools/library/include/cutlass/library/library.h index d093b6118c..f692437199 100644 --- a/tools/library/include/cutlass/library/library.h +++ b/tools/library/include/cutlass/library/library.h @@ -65,6 +65,8 @@ enum class LayoutTypeID { kUnknown, kColumnMajor, kRowMajor, + kColumnMajorInterleavedK2, + kRowMajorInterleavedK2, kColumnMajorInterleavedK4, kRowMajorInterleavedK4, kColumnMajorInterleavedK16, @@ -74,7 +76,9 @@ enum class LayoutTypeID { kColumnMajorInterleavedK64, kRowMajorInterleavedK64, kTensorNCHW, + kTensorNCDHW, kTensorNHWC, + kTensorNDHWC, kInvalid }; @@ -83,11 +87,13 @@ enum class NumericTypeID { kUnknown, kVoid, kB1, + kU2, kU4, kU8, kU16, kU32, kU64, + kS2, kS4, kS8, kS16, @@ -103,11 +109,13 @@ enum class NumericTypeID { kCF32, kCTF32, kCF64, + kCS2, kCS4, kCS8, kCS16, kCS32, kCS64, + kCU2, kCU4, kCU8, kCU16, @@ -116,7 +124,7 @@ enum class NumericTypeID { kInvalid }; -/// Enumeraed type describing a transformation on a complex value. +/// Enumerated type describing a transformation on a complex value. 
enum class ComplexTransform { kNone, kConjugate, @@ -139,6 +147,7 @@ enum class Provider { enum class OperationKind { kGemm, kEqGemm, + kSparseGemm, kReduction, kInvalid }; @@ -164,6 +173,7 @@ enum class OpcodeClassID { kSimt, kTensorOp, kWmmaTensorOp, + kSparseTensorOp, kInvalid }; @@ -171,6 +181,8 @@ enum class MathOperationID { kAdd, kMultiplyAdd, kMultiplyAddSaturate, + kMultiplyAddFastBF16, + kMultiplyAddFastF16, kMultiplyAddComplex, kMultiplyAddGaussianComplex, kXorPopc, @@ -182,8 +194,7 @@ enum class MathOperationID { /// Enumeration indicating what kind of GEMM operation to perform enum class GemmKind { kGemm, - kBatched, - kArray, + kSparse, kUniversal, kPlanarComplex, kPlanarComplexArray, @@ -392,6 +403,9 @@ struct GemmDescription : public OperationDescription { /// Describes the source and destination matrices TensorDescription C; + /// Describes the sparse meta matrices + TensorDescription E; + /// Describes the data type of the scalars passed to the epilogue NumericTypeID element_epilogue; @@ -428,6 +442,26 @@ struct GemmDescription : public OperationDescription { transform_B(transform_B) {} }; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Desciprion for structured sparse GEMMs. +struct SparseGemmDescription : public GemmDescription { + + /// Description structure for structured sparse GEMM + SparseGemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const &A = TensorDescription(), + TensorDescription const &B = TensorDescription(), + TensorDescription const &C = TensorDescription(), + TensorDescription const &E = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + GemmDescription(gemm_kind, A, B, C, element_epilogue, split_k_mode, transform_A, transform_B) + {this->E = E;} +}; /// Description of all Reduction operations struct ReductionDescription : public OperationDescription { @@ -747,6 +781,48 @@ struct GemmPlanarComplexArrayArguments { ScalarPointerMode pointer_mode; }; +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// OperationKind: kSparseGemm +// + +/// Computes GEMM assumine one of the inputs has 2:4 structured sparsity. 
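// With 2:4 structured sparsity, the A operand is stored in compressed form
// (K / 2 logical columns) while the E operand carries the metadata selecting
// the nonzero positions; lde and batch_stride_E below describe this extra
// operand alongside the usual leading dimensions and batch strides.
// Illustrative sketch only, assuming row-major A/B/C/D and a single batch
// (member names are taken from the struct below; the values are hypothetical):
//
//   SparseGemmConfiguration config;
//   config.problem_size = {m, n, k};   // logical GEMM extents
//   config.batch_count  = 1;
//   config.lda = k / 2;                // compressed A holds k/2 columns
//   config.ldb = n;
//   config.ldc = n;
//   config.ldd = n;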
+struct SparseGemmConfiguration { + + GemmUniversalMode mode; + gemm::GemmCoord problem_size; + int batch_count; /// number of sparse matrix products in batch + + int64_t lda; /// leading dimension of A operand + int64_t ldb; /// leading dimension of B operand + int64_t ldc; /// leading dimension of C operand + int64_t ldd; /// leading dimension of D operand + int64_t lde; /// leading dimension of E operand (metadata matrix) + + int64_t batch_stride_A; // stride between matrices + int64_t batch_stride_B; // stride between matrices + int64_t batch_stride_C; // stride between matrices + int64_t batch_stride_D; // stride between matrices + int64_t batch_stride_E; // stride between matrices +}; + +/// Arguments for sparse GEMMs +struct SparseGemmArguments { + + void const *A; /// pointer to A matrix + void const *B; /// pointer to B matrix + void const *C; /// pointer to C matrix + void *D; /// pointer to D matrix + void const *E; /// pointer to E matric (metadata) + + void const *alpha; /// pointer to alpha scalar + void const *beta; /// pointer to beta scalar + ScalarPointerMode pointer_mode; /// enumerant indicating whether alpha/beta pointers are host + /// or device pointers. +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace library diff --git a/tools/library/include/cutlass/library/manifest.h b/tools/library/include/cutlass/library/manifest.h index 54e51c1fd0..7adf0fbbce 100644 --- a/tools/library/include/cutlass/library/manifest.h +++ b/tools/library/include/cutlass/library/manifest.h @@ -48,7 +48,7 @@ namespace library { // Forward declaration class Manifest; -// init and insert all cutlass gemm and conv2d op in manifest object (procedurally generated using generator.py) +// init and insert all cutlass gemm operations in manifest object (procedurally generated using generator.py) void initialize_all(Manifest &manifest); ///////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/scripts/gemm_operation.py b/tools/library/scripts/gemm_operation.py index 66ecc05e69..0a76f36bf9 100644 --- a/tools/library/scripts/gemm_operation.py +++ b/tools/library/scripts/gemm_operation.py @@ -216,8 +216,6 @@ def __init__(self): def emit(self, operation): warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] - #warp_shape[2] = operation.tile_description.math_instruction.instruction_shape[2] - warp_shape[2] = operation.tile_description.threadblock_shape[2] epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) @@ -262,6 +260,86 @@ def emit(self, operation): ################################################################################################### +class EmitSparseGemmInstance: + ''' Responsible for emitting a CUTLASS template definition''' + + def __init__(self): + self.gemm_template = """ + // Gemm operator ${operation_name} + using Operation_${operation_name} = cutlass::gemm::device::SparseGemm< + ${element_a}, ${layout_a}, + ${element_b}, ${layout_b}, + ${element_c}, ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, 
${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + false, + ${math_operation} + ${residual} + >; +""" + + def emit(self, operation): + + warp_shape = [operation.tile_description.threadblock_shape[idx] // operation.tile_description.warp_count[idx] for idx in range(3)] + + epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) + + residual = '' + + values = { + 'operation_name': operation.procedural_name(), + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[operation.A.layout], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[operation.B.layout], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[operation.C.layout], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str(operation.tile_description.threadblock_shape[0]), + 'threadblock_shape_n': str(operation.tile_description.threadblock_shape[1]), + 'threadblock_shape_k': str(operation.tile_description.threadblock_shape[2]), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str(operation.tile_description.math_instruction.instruction_shape[0]), + 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), + 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), + 'epilogue_functor': EpilogueFunctorTag[operation.epilogue_functor], + 'swizzling_functor': SwizzlingFunctorTag[operation.swizzling_functor], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], + 'residual': residual + } + + template = self.gemm_template + + return SubstituteTemplate(template, values) + +################################################################################################### + + # class EmitGemmUniversalInstance: ''' Responsible for emitting a CUTLASS template definition''' @@ -330,7 +408,6 @@ def emit(self, operation): warp_count = operation.tile_description.warp_count warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] - warp_shape[2] = operation.tile_description.threadblock_shape[2] epilogue_vector_length = int(min(operation.C.alignment * DataTypeSize[operation.C.element], 128) / DataTypeSize[operation.C.element]) @@ -552,6 +629,7 @@ def __init__(self, operation_path, configuration_name): self.instance_emitter = { GemmKind.Gemm: EmitGemmInstance, + GemmKind.Sparse: EmitSparseGemmInstance, GemmKind.Universal: EmitGemmUniversalInstance, GemmKind.PlanarComplex: EmitGemmPlanarComplexInstance, GemmKind.PlanarComplexArray: EmitGemmPlanarComplexArrayInstance @@ -559,6 
+637,7 @@ def __init__(self, operation_path, configuration_name): self.gemm_kind_wrappers = { GemmKind.Gemm: 'GemmOperation', + GemmKind.Sparse: 'GemmSparseOperation', GemmKind.Universal: 'GemmUniversalOperation', GemmKind.PlanarComplex: 'GemmPlanarComplexOperation', GemmKind.PlanarComplexArray: 'GemmPlanarComplexArrayOperation' @@ -571,6 +650,11 @@ def __init__(self, operation_path, configuration_name): ${compile_guard_start} manifest.append(new ${gemm_kind}("${operation_name}")); ${compile_guard_end} +""", + GemmKind.Sparse: """ +${compile_guard_start} + manifest.append(new ${gemm_kind}("${operation_name}")); +${compile_guard_end} """, GemmKind.Universal: """ ${compile_guard_start} diff --git a/tools/library/scripts/generator.py b/tools/library/scripts/generator.py index 2957864568..f21acaaf6e 100644 --- a/tools/library/scripts/generator.py +++ b/tools/library/scripts/generator.py @@ -69,6 +69,44 @@ def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \ return operations +# +def CreateSparseGemmOperator(manifest, layouts, tile_descriptions, data_type, \ + alignment_constraints, complex_transforms = None, epilogue_functor = EpilogueFunctor.LinearCombination, \ + swizzling_functor = SwizzlingFunctor.Identity8): + + if complex_transforms is None: + complex_transforms = [(ComplexTransform.none, ComplexTransform.none),] + + element_a, element_b, element_c, element_epilogue = data_type + + gemm_kinds = [GemmKind.Sparse] + + operations = [] + + # by default, only generate the largest tile and largest alignment + if manifest.args.kernels == '': + tile_descriptions = [tile_descriptions[0],] + alignment_constraints = [alignment_constraints[0],] + + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + for complex_transform in complex_transforms: + + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment, complex_transform[0]) + B = TensorDescription(element_b, layout[1], alignment, complex_transform[1]) + C = TensorDescription(element_c, layout[2], alignment_c) + + new_operation = GemmOperation(GemmKind.Sparse, tile_description.minimum_compute_capability, \ + tile_description, A, B, C, element_epilogue, epilogue_functor, swizzling_functor) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + # def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_type, \ alignment_constraints, complex_transforms): @@ -102,7 +140,7 @@ def CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, data_t tile_description, A, B, C, element_epilogue)) return -################################################################################################### +########################################################################################################### ################################################################################################### ################################################################################################### @@ -152,6 +190,7 @@ def GenerateSM50_Simt(manifest, args): CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints) + # # @@ -463,9 +502,52 @@ def GenerateSM70_WmmaTensorOp_161616(manifest, args): data_type_mixed, alignment_constraints) # +################################################################################################## +# +def GenerateSM70_Simt_complex(manifest, args): + math_instructions = [ + MathInstruction( \ + [1, 1, 
1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add_complex), + ] + + min_cc = 70 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 8], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + data_type = [ + DataType.cf32, + DataType.cf32, + DataType.cf32, + DataType.cf32 + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + +# + def GenerateSM70(manifest, args): GenerateSM70_TensorOp_884(manifest, args) GenerateSM70_PlanarComplexTensorOp_884(manifest, args) + GenerateSM70_Simt_complex(manifest, args) # To limit build size, WMMA GEMMs are disabled for now. # @@ -662,7 +744,7 @@ def GenerateSM75_TensorOp_8816_TN(manifest, args): data_type_mixed = [ math_inst.element_a, math_inst.element_b, - DataType.s8, + math_inst.element_a, DataType.f32, ] @@ -720,7 +802,7 @@ def GenerateSM75_TensorOp_8816_Interleaved(manifest, args): data_type_mixed = [ math_inst.element_a, math_inst.element_b, - DataType.s8, + math_inst.element_a, DataType.f32, ] @@ -786,7 +868,7 @@ def GenerateSM75_TensorOp_8832_TN(manifest, args): data_type_mixed = [ math_inst.element_a, math_inst.element_b, - DataType.s4, + math_inst.element_a, DataType.f32, ] @@ -849,7 +931,7 @@ def GenerateSM75_TensorOp_8832_Interleaved(manifest, args): data_type_mixed = [ math_inst.element_a, math_inst.element_b, - DataType.s4, + math_inst.element_a, DataType.f32, ] @@ -861,6 +943,46 @@ def GenerateSM75_TensorOp_8832_Interleaved(manifest, args): # +# +def GenerateSM75_TensorOp_88128(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + + math_instructions = [ + MathInstruction( \ + [8, 8, 128], \ + DataType.b1, DataType.b1, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.xor_popc), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [128,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 512], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 512], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 512], 2, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.b1, DataType.b1, DataType.s32, DataType.s32] + + CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + +# + # def GenerateSM75_WmmaTensorOp_161616(manifest, args): @@ -920,6 +1042,40 @@ def GenerateSM75_WmmaTensorOp_161616(manifest, args): # # +def 
GenerateSM75_Simt_complex(manifest, args): + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add_complex), + ] + + min_cc = 75 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc) + ] + data_type = [ + DataType.cf32, + DataType.cf32, + DataType.cf32, + DataType.cf32 + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + +# + def GenerateSM75(manifest, args): GenerateSM75_TensorOp_1688(manifest, args) GenerateSM75_PlanarComplexTensorOp_1688(manifest, args) @@ -927,7 +1083,10 @@ def GenerateSM75(manifest, args): GenerateSM75_TensorOp_8816_Interleaved(manifest, args) GenerateSM75_TensorOp_8832_TN(manifest, args) GenerateSM75_TensorOp_8832_Interleaved(manifest, args) + GenerateSM75_TensorOp_88128(manifest, args) #GenerateSM75_WmmaTensorOp_161616(manifest, args) + GenerateSM75_Simt_complex(manifest, args) + ################################################################################################### ################################################################################################### @@ -972,23 +1131,20 @@ def GenerateSM80_TensorOp_16816(manifest, args): tile_descriptions = [ TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 3, [1, 2, 2], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 3, [2, 1, 2], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 4, [1, 2, 2], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 4, [2, 1, 2], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 4, [1, 2, 2], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 5, [1, 2, 2], math_inst, min_cc, max_cc), TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), ] 
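    # TileDescription arguments above, as used throughout this generator: the
    # threadblock tile [M, N, K], the number of pipeline stages, the warp count
    # per threadblock, the math instruction, and the minimum/maximum compute
    # capability the kernel is emitted for.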
data_type = [ @@ -1016,6 +1172,84 @@ def GenerateSM80_TensorOp_16816(manifest, args): # +# +def GenerateSM80_SparseTensorOp_16832(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 11, 1): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 32], \ + DataType.f16, DataType.f16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 32], \ + DataType.f16, DataType.f16, DataType.f16, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + MathInstruction( \ + [16, 8, 32], \ + DataType.bf16, DataType.bf16, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [8, 4, 2] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + +# + # def GenerateSM80_PlanarComplexTensorOp_16816(manifest, args): @@ -1119,26 +1353,26 @@ def GenerateSM80_TensorOp_16832_TN(manifest, args): for math_inst in math_instructions: tile_descriptions = [ - TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([64, 256, 128], 3, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), ] - data_type = [math_inst.element_a, math_inst.element_b, DataType.s32, DataType.s32] - data_type_mixed = [math_inst.element_a, math_inst.element_b, DataType.s8, DataType.f32] + data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32] + data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) @@ -1156,6 +1390,61 @@ def GenerateSM80_TensorOp_16832_TN(manifest, args): # +# 
+def GenerateSM80_SparseTensorOp_16864_TN(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 11, 1): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 64], \ + DataType.s8, DataType.s8, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [16,] + + tile_descriptions = [ + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.s8, DataType.s8, DataType.s32, DataType.s32] + data_type_mixed = [DataType.s8, DataType.s8, DataType.s8, DataType.f32] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + operations = [] + + operations += CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 16 + else: + op.C.alignment = 8 +# + # def GenerateSM80_TensorOp_16832_Interleaved(manifest, args): @@ -1186,15 +1475,17 @@ def GenerateSM80_TensorOp_16832_Interleaved(manifest, args): for math_inst in math_instructions: tile_descriptions = [ - TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 10, [2, 2, 1], math_inst, min_cc, max_cc), ] - data_type_mixed = [math_inst.element_a, math_inst.element_b, DataType.s8, DataType.f32] + data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ 
data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) @@ -1234,22 +1525,26 @@ def GenerateSM80_TensorOp_16864_TN(manifest, args): for math_inst in math_instructions: tile_descriptions = [ - TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 256], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 256], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 256], 5, [2, 2, 1], math_inst, min_cc, max_cc), ] - data_type = [math_inst.element_a, math_inst.element_b, DataType.s32, DataType.s32] - data_type_mixed = [math_inst.element_a, math_inst.element_b, DataType.s4, DataType.f32] + data_type = [math_inst.element_a, math_inst.element_b, math_inst.element_accumulator, DataType.s32] + data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) @@ -1268,6 +1563,63 @@ def GenerateSM80_TensorOp_16864_TN(manifest, args): op.C.alignment = 4 # +# +def GenerateSM80_SparseTensorOp_168128_TN(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 11, 1): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + ] + + math_inst = \ + MathInstruction( \ + [16, 8, 128], \ + DataType.s4, DataType.s4, DataType.s32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add_saturate) + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [32,] + + tile_descriptions = [ + TileDescription([256, 128, 256], 3, [4, 2, 1], math_inst, min_cc, max_cc), + 
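# Hedged aside (not part of the patch): each TileDescription above bundles a threadblock
# tile [M, N, K], a pipeline stage count, and a warp arrangement, and the sparse s8
# generator above derives the epilogue alignment from the threadblock N dimension
# (16 when N >= 128, otherwise 8). A minimal standalone restatement of that rule:
def int8_sparse_output_alignment(threadblock_shape):
    # threadblock_shape is [M, N, K]; wide-N tiles can store 16-element s8 vectors.
    return 16 if threadblock_shape[1] >= 128 else 8

# e.g. int8_sparse_output_alignment([256, 128, 128]) -> 16
#      int8_sparse_output_alignment([128, 64, 256])  -> 8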
TileDescription([128, 256, 256], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 256], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 256], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 256], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 256], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 256], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 512], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 512], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 512], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 512], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.s4, DataType.s4, DataType.s32, DataType.s32] + data_type_mixed = [DataType.s4, DataType.s4, DataType.s4, DataType.f32] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + operations = [] + + operations += CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombinationClamp) + + for op in operations: + if op.tile_description.threadblock_shape[1] >= 128: + op.C.alignment = 8 + elif op.tile_description.threadblock_shape[1] == 64: + op.C.alignment = 8 + else: + op.C.alignment = 4 +# + # def GenerateSM80_TensorOp_16864_Interleaved(manifest, args): @@ -1298,15 +1650,17 @@ def GenerateSM80_TensorOp_16864_Interleaved(manifest, args): for math_inst in math_instructions: tile_descriptions = [ - TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 128], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 128], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 128], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 128], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 128], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 128], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 128], 10, [2, 2, 1], math_inst, min_cc, max_cc), ] - data_type_mixed = [math_inst.element_a, math_inst.element_b, DataType.s4, DataType.f32] + data_type_mixed = [math_inst.element_a, math_inst.element_b, math_inst.element_a, DataType.f32] operations = [] @@ -1342,18 +1696,22 @@ def GenerateSM80_TensorOp_168256(manifest, args): for math_inst in math_instructions: tile_descriptions = [ - TileDescription([256, 128, 512], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 512], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 512], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 512], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 512], 4, [2, 2, 1], math_inst, min_cc, max_cc), - 
TileDescription([ 64, 64, 512], 5, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([256, 128, 1024], 3, [4, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 256, 1024], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 1024], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 1024], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 1024], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 1024], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 512], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 512], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 512], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 512], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 512], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 512], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 512], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 512], 10, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 1024], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 1024], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 1024], 4, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 1024], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 1024], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 1024], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 1024], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 1024], 5, [2, 2, 1], math_inst, min_cc, max_cc), ] data_type = [DataType.b1, DataType.b1, DataType.s32, DataType.s32] @@ -1393,23 +1751,20 @@ def GenerateSM80_TensorOp_1688(manifest, args): tile_descriptions = [ TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 32], 3, [1, 2, 2], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 3, [2, 1, 2], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 32], 4, [1, 2, 2], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 4, [2, 1, 2], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 4, [1, 2, 2], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 5, [1, 2, 2], math_inst, min_cc, max_cc), TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 32], 3, [1, 4, 
1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), ] data_type = [ @@ -1474,23 +1829,20 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args): tile_descriptions = [ TileDescription([256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 32], 3, [1, 2, 2], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 3, [2, 1, 2], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 32], 4, [1, 2, 2], math_inst, min_cc, max_cc), - TileDescription([128, 64, 32], 4, [2, 1, 2], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 4, [1, 2, 2], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 5, [1, 2, 2], math_inst, min_cc, max_cc), TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), - TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 256, 32], 3, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), ] data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] @@ -1500,6 +1852,55 @@ def GenerateSM80_TensorOp_1688_fast_math(manifest, args): # +# +def GenerateSM80_SparseTensorOp_16816_fast_math(manifest, args): + + if not CudaToolkitVersionSatisfies(args.cuda_version, 11, 1): + return + + layouts = [ + (LayoutType.ColumnMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + (LayoutType.ColumnMajor, LayoutType.RowMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.RowMajor), + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), + ] + + math_instructions = [ + MathInstruction( \ + [16, 8, 16], \ + DataType.tf32, DataType.tf32, DataType.f32, \ + OpcodeClass.TensorOp, \ + MathOperation.multiply_add), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc), + 
TileDescription([128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 32], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateSparseGemmOperator(manifest, layouts, tile_descriptions, \ + data_type, alignment_constraints) +# + # def GenerateSM80_TensorOp_1688_complex(manifest, args): @@ -1734,30 +2135,78 @@ def GenerateSM80_Simt(manifest, args): CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints) + +# + +################################################################################################## +# +def GenerateSM80_Simt_complex(manifest, args): + math_instructions = [ + MathInstruction( \ + [1, 1, 1], \ + DataType.f32, DataType.f32, DataType.f32, \ + OpcodeClass.Simt, \ + MathOperation.multiply_add_complex), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [1,] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription([128, 128, 8], 5, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 128, 8], 4, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 64, 8], 3, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([64, 32, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([32, 32, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc), + ] + data_type = [ + DataType.cf32, + DataType.cf32, + DataType.cf32, + DataType.cf32 + ] + + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + (ComplexTransform.conj, ComplexTransform.none), + (ComplexTransform.none, ComplexTransform.conj), + (ComplexTransform.conj, ComplexTransform.conj) + ] + # ################################################################################################### # def GenerateSM80(manifest, args): - GenerateSM80_TensorOp_16816(manifest, args) + GenerateSM80_SparseTensorOp_16832(manifest, args) GenerateSM80_PlanarComplexTensorOp_16816(manifest, args) GenerateSM80_TensorOp_1688(manifest, args) GenerateSM80_TensorOp_1688_fast_math(manifest, args) + GenerateSM80_SparseTensorOp_16816_fast_math(manifest, args) GenerateSM80_TensorOp_1688_complex(manifest, args) GenerateSM80_TensorOp_884(manifest, args) GenerateSM80_TensorOp_884_complex(manifest, args) GenerateSM80_TensorOp_884_complex_gaussian(manifest, args) GenerateSM80_TensorOp_16832_TN(manifest, args) + GenerateSM80_SparseTensorOp_16864_TN(manifest, args) GenerateSM80_TensorOp_16832_Interleaved(manifest, args) GenerateSM80_TensorOp_16864_TN(manifest, args) + 
GenerateSM80_SparseTensorOp_168128_TN(manifest, args) GenerateSM80_TensorOp_16864_Interleaved(manifest, args) GenerateSM80_TensorOp_168256(manifest, args) GenerateSM80_Simt(manifest, args) -# + GenerateSM80_Simt_complex(manifest, args) ################################################################################################### +################################################################################################### if __name__ == "__main__": @@ -1768,7 +2217,11 @@ def GenerateSM80(manifest, args): parser.add_argument("--generator-target", default='library', help="Target of CUTLASS Library Generator.") parser.add_argument("--architectures", default='53;60;61;70;75;80', help="Target compute architectures") parser.add_argument("--kernels", default='', help='Comma delimited list to filter kernels by name.') + parser.add_argument("--ignore-kernels", default='', help='Comma delimited list of kernels to exclude from build.') parser.add_argument("--cuda-version", default="11.0.0", help="Semantic version string of CUDA Toolkit") + parser.add_argument('--kernel-filter-file', type=str, default=None, required=False, help='Full path of filter file') + parser.add_argument('--selected-kernel-list', type=str, default=None, required=False, + help='Specify the output log file containing all enabled kernels in this build') args = parser.parse_args() @@ -1780,9 +2233,13 @@ def GenerateSM80(manifest, args): GenerateSM70(manifest, args) GenerateSM75(manifest, args) GenerateSM80(manifest, args) - if 'library' in args.generator_target.split(','): manifest.emit(GeneratorTarget.Library) + if args.selected_kernel_list is not None: + if len(manifest.selected_kernels) > 0: + with open(args.selected_kernel_list, 'w') as file_writer: + for line in manifest.selected_kernels: + file_writer.write("%s\n" % line) # ################################################################################################### diff --git a/tools/library/scripts/library.py b/tools/library/scripts/library.py index bdc4348308..2bb062da95 100644 --- a/tools/library/scripts/library.py +++ b/tools/library/scripts/library.py @@ -265,6 +265,7 @@ class LayoutType(enum.Enum): ColumnMajorInterleaved64 = enum_auto() RowMajorInterleaved64 = enum_auto() TensorNHWC = enum_auto() + TensorNDHWC = enum_auto() TensorNCHW = enum_auto() TensorNGHWC = enum_auto() TensorNCxHW32 = enum_auto() @@ -279,6 +280,7 @@ class LayoutType(enum.Enum): LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>', LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>', LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC', + LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC', LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW', LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC', LayoutType.TensorNCxHW32: 'cutlass::layout::TensorNCxHW32', @@ -305,6 +307,7 @@ class LayoutType(enum.Enum): LayoutType.RowMajorInterleaved32: 't32', LayoutType.RowMajorInterleaved64: 't64', LayoutType.TensorNHWC: 'nhwc', + LayoutType.TensorNDHWC: 'ndhwc', LayoutType.TensorNCHW: 'nchw', LayoutType.TensorNGHWC: 'nghwc', LayoutType.TensorNCxHW32: 'ncxhw32', @@ -320,7 +323,6 @@ class LayoutType(enum.Enum): } ################################################################################################### - # class OpcodeClass(enum.Enum): Simt = enum_auto() @@ -383,8 +385,7 @@ def SubstituteTemplate(template, values): # class GemmKind(enum.Enum): Gemm = enum_auto() - Batched = enum_auto() - Array = enum_auto() + Sparse = enum_auto() 
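# Hedged aside (not part of the patch): the --selected-kernel-list option added above makes
# the generator log every kernel name that survives filtering. Functionally it reduces to a
# plain write loop like the sketch below; manifest.selected_kernels is assumed to be the
# list of procedural names collected in Manifest.append().
def write_selected_kernels(path, selected_kernels):
    if path is not None and len(selected_kernels) > 0:
        with open(path, 'w') as file_writer:
            for name in selected_kernels:
                file_writer.write("%s\n" % name)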
Universal = enum_auto() PlanarComplex = enum_auto() PlanarComplexArray = enum_auto() @@ -392,8 +393,7 @@ class GemmKind(enum.Enum): # GemmKindNames = { GemmKind.Gemm: "gemm", - GemmKind.Batched: "gemm_batched", - GemmKind.Array: "gemm_array", + GemmKind.Sparse: "spgemm", GemmKind.Universal: "gemm", GemmKind.PlanarComplex: "gemm_planar_complex", GemmKind.PlanarComplexArray: "gemm_planar_complex_array", diff --git a/tools/library/scripts/manifest.py b/tools/library/scripts/manifest.py index 756ddc7263..2f0aa24ecb 100644 --- a/tools/library/scripts/manifest.py +++ b/tools/library/scripts/manifest.py @@ -113,6 +113,7 @@ def __init__(self, args): self.operations = {} self.args = args self.compute_capabilities = [int(x) for x in args.architectures.split(';')] + self.selected_kernels = [] if args.operations == 'all': self.operations_enabled = [] @@ -129,6 +130,14 @@ def __init__(self, args): else: self.kernel_names = [x for x in args.kernels.split(',') if x != ''] + self.ignore_kernel_names = [x for x in args.ignore_kernels.split(',') if x != ''] + + if args.kernel_filter_file is None: + self.kernel_filter_list = [] + else: + self.kernel_filter_list = self.get_kernel_filters(args.kernel_filter_file) + + self.operation_count = 0 self.operations_by_name = {} self.top_level_prologue = ''' @@ -152,6 +161,29 @@ def __init__(self, args): } // namespace cutlass ''' + + + def get_kernel_filters (self, kernelListFile): + if os.path.isfile(kernelListFile): + with open(kernelListFile, 'r') as fileReader: + lines = [line.rstrip() for line in fileReader if not line.startswith("#")] + + lines = [re.compile(line) for line in lines if line] + return lines + else: + return [] + + + + def filter_out_kernels(self, kernel_name, kernel_filter_list): + + for kernel_filter_re in kernel_filter_list: + if kernel_filter_re.search(kernel_name) is not None: + return True + + return False + + # def _filter_string_matches(self, filter_string, haystack): ''' Returns true if all substrings appear in the haystack in order''' @@ -190,11 +222,25 @@ def filter(self, operation): if len(self.kernel_names): name = operation.procedural_name() enabled = False + + # compare against the include list for name_substr in self.kernel_names: if self._filter_string_matches(name_substr, name): enabled = True break + # compare against the exclude list + for name_substr in self.ignore_kernel_names: + if self._filter_string_matches(name_substr, name): + enabled = False + break + + if len(self.kernel_filter_list) > 0: + enabled = False + if self.filter_out_kernels(operation.procedural_name(), self.kernel_filter_list): + enabled = True + + # todo: filter based on compute data type return enabled # @@ -208,6 +254,8 @@ def append(self, operation): ''' if self.filter(operation): + + self.selected_kernels.append(operation.procedural_name()) self.operations_by_name[operation.procedural_name()] = operation @@ -260,7 +308,8 @@ def emit(self, target = GeneratorTarget.Library): self.top_level_reserve, {'operation_count': str(self.operation_count)})) # for each operation kind, emit initializer for all configurations - for operation_kind, configurations in self.operations.items(): + for operation_kind, configurations in self.operations.items(): + with operation_emitters[target](generated_path, operation_kind, self.args) as operation_kind_emitter: for configuration_name, operations in configurations.items(): operation_kind_emitter.emit(configuration_name, operations) diff --git a/tools/library/src/gemm_operation.h b/tools/library/src/gemm_operation.h index 
23781b25ed..d65e3414d5 100644 --- a/tools/library/src/gemm_operation.h +++ b/tools/library/src/gemm_operation.h @@ -27,10 +27,10 @@ */ #pragma once - #include "cutlass/cutlass.h" #include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_sparse.h" #include "cutlass/gemm/device/gemm_complex.h" #include "cutlass/gemm/device/gemm_batched.h" #include "cutlass/gemm/device/gemm_array.h" @@ -325,7 +325,7 @@ class GemmOperation : public GemmOperationBase { /////////////////////////////////////////////////////////////////////////////////////////////////// template -class GemmBatchedOperation : public GemmOperationBase { +class GemmSparseOperation : public GemmOperationBase { public: using Operator = Operator_; @@ -335,22 +335,21 @@ class GemmBatchedOperation : public GemmOperationBase { using LayoutB = typename Operator::LayoutB; using ElementC = typename Operator::ElementC; using LayoutC = typename Operator::LayoutC; + using ElementE = typename Operator::ElementE; + using LayoutE = typename Operator::LayoutE; using ElementAccumulator = typename Operator::ElementAccumulator; using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; using OperatorArguments = typename Operator::Arguments; -protected: - - /// - GemmDescription description_; - public: /// Constructor - GemmBatchedOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + GemmSparseOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { - description_.gemm_kind = GemmKind::kBatched; + this->description_.kind = OperationKind::kSparseGemm; + this->description_.gemm_kind = GemmKind::kSparse; + this->description_.E = make_TensorDescription(Operator::kAlignmentE); } protected: @@ -358,19 +357,14 @@ class GemmBatchedOperation : public GemmOperationBase { /// Constructs the arguments structure given the configuration and arguments static Status construct_arguments_( OperatorArguments &operator_args, - GemmBatchedConfiguration const *configuration) { + SparseGemmConfiguration const *configuration) { operator_args.problem_size = configuration->problem_size; operator_args.ref_A = {nullptr, int(configuration->lda)}; - operator_args.stride_A = configuration->batch_stride_A; operator_args.ref_B = {nullptr, int(configuration->ldb)}; - operator_args.stride_B = configuration->batch_stride_B; operator_args.ref_C = {nullptr, int(configuration->ldc)}; - operator_args.stride_C = configuration->batch_stride_C; operator_args.ref_D = {nullptr, int(configuration->ldd)}; - operator_args.stride_D = configuration->batch_stride_D; - - operator_args.batch_count = configuration->batch_count; + operator_args.ref_E = {nullptr, int(configuration->lde)}; return Status::kSuccess; } @@ -378,14 +372,13 @@ class GemmBatchedOperation : public GemmOperationBase { /// Constructs the arguments structure given the configuration and arguments static Status update_arguments_( OperatorArguments &operator_args, - GemmBatchedArguments const *arguments) { + SparseGemmArguments const *arguments) { if (arguments->pointer_mode == ScalarPointerMode::kHost) { typename Operator::EpilogueOutputOp::Params params( *static_cast(arguments->alpha), *static_cast(arguments->beta) ); - operator_args.epilogue = params; } else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ @@ -403,27 +396,23 @@ class GemmBatchedOperation : public GemmOperationBase { operator_args.ref_B.reset(static_cast(arguments->B)); operator_args.ref_C.reset(static_cast(arguments->C)); operator_args.ref_D.reset(static_cast(arguments->D)); + 
operator_args.ref_E.reset(static_cast(arguments->E)); return Status::kSuccess; } public: - /// Returns the description of the GEMM operation - virtual OperationDescription const & description() const { - return description_; - } - /// Returns success if the operation can proceed virtual Status can_implement( void const *configuration_ptr, void const *arguments_ptr) const { - GemmBatchedConfiguration const *configuration = - static_cast(configuration_ptr); + SparseGemmConfiguration const *configuration = + static_cast(configuration_ptr); - GemmBatchedArguments const *arguments = - static_cast(arguments_ptr); + SparseGemmArguments const *arguments = + static_cast(arguments_ptr); OperatorArguments args; @@ -457,7 +446,7 @@ class GemmBatchedOperation : public GemmOperationBase { Status status = construct_arguments_( args, - static_cast(configuration_ptr)); + static_cast(configuration_ptr)); if (status != Status::kSuccess) { return 0; @@ -477,7 +466,7 @@ class GemmBatchedOperation : public GemmOperationBase { Status status = construct_arguments_( args, - static_cast(configuration_ptr)); + static_cast(configuration_ptr)); if (status != Status::kSuccess) { return status; @@ -499,7 +488,7 @@ class GemmBatchedOperation : public GemmOperationBase { Status status = update_arguments_( args, - static_cast(arguments_ptr)); + static_cast(arguments_ptr)); if (status != Status::kSuccess) { return status; @@ -515,191 +504,26 @@ class GemmBatchedOperation : public GemmOperationBase { return op->run(stream); } -}; - - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -template -class GemmArrayOperation : public GemmOperationBase { -public: - - using Operator = Operator_; - using ElementA = typename Operator::ElementA; - using LayoutA = typename Operator::LayoutA; - using ElementB = typename Operator::ElementB; - using LayoutB = typename Operator::LayoutB; - using ElementC = typename Operator::ElementC; - using LayoutC = typename Operator::LayoutC; - using ElementAccumulator = typename Operator::ElementAccumulator; - using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; - - using OperatorArguments = typename Operator::Arguments; - -protected: - - /// - GemmDescription description_; - -public: - - /// Constructor - GemmArrayOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { - - description_.gemm_kind = GemmKind::kArray; - } - -protected: - - /// Constructs the arguments structure given the configuration and arguments - static Status construct_arguments_( - OperatorArguments &operator_args, - GemmArrayConfiguration const *configuration) { - - operator_args.problem_size = configuration->problem_size; - - operator_args.batch_count = configuration->batch_count; - - return Status::kSuccess; - } - - /// Constructs the arguments structure given the configuration and arguments - static Status update_arguments_( - OperatorArguments &operator_args, - GemmArrayArguments const *arguments) { - - if (arguments->pointer_mode == ScalarPointerMode::kHost) { - typename Operator::EpilogueOutputOp::Params params( - *static_cast(arguments->alpha), - *static_cast(arguments->beta) - ); - operator_args.epilogue = params; - } - else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ - typename Operator::EpilogueOutputOp::Params params( - static_cast(arguments->alpha), - static_cast(arguments->beta) - ); - operator_args.epilogue = params; - } - else { - return Status::kErrorInvalidProblem; - } - - return Status::kSuccess; - } - 
-public: - - /// Returns the description of the GEMM operation - virtual OperationDescription const & description() const { - return description_; - } - /// Returns success if the operation can proceed - virtual Status can_implement( - void const *configuration_ptr, - void const *arguments_ptr) const { - - GemmArrayConfiguration const *configuration = - static_cast(configuration_ptr); - - GemmArrayArguments const *arguments = - static_cast(arguments_ptr); - - OperatorArguments args; - - Status status = construct_arguments_(args, configuration); - - if (status != Status::kSuccess) { - return status; - } - - status = update_arguments_(args, arguments); - - if (status != Status::kSuccess) { - return status; - } - - return Operator::can_implement(args); - } - - /// Gets the host-side workspace - virtual uint64_t get_host_workspace_size( - void const *configuration) const { - - return sizeof(Operator); - } - - /// Gets the device-side workspace - virtual uint64_t get_device_workspace_size( - void const *configuration_ptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return 0; - } - - return Operator::get_workspace_size(args); - } - - /// Initializes the workspace - virtual Status initialize( - void const *configuration_ptr, - void *host_workspace, - void *device_workspace, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = construct_arguments_( - args, - static_cast(configuration_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = new (host_workspace) Operator; - - return op->initialize(args, device_workspace, stream); - } - - /// Runs the kernel - virtual Status run( - void const *arguments_ptr, - void *host_workspace, - void *device_workspace = nullptr, - cudaStream_t stream = nullptr) const { - - OperatorArguments args; - - Status status = update_arguments_( - args, - static_cast(arguments_ptr)); - - if (status != Status::kSuccess) { - return status; - } - - Operator *op = static_cast(host_workspace); - - status = op->update(args, device_workspace); - - if (status != Status::kSuccess) { - return status; - } - - return op->run(stream); + void print_operator_args(OperatorArguments &operator_args) const { +#if 0 + std::cout << "GemmOperation::OperatorArguments" << std::endl; + std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl; + std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl; + std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl; + std::cout << " beta: " << operator_args.epilogue.beta << std::endl; + std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl; + std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl; + std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl; + std::cout << " ref_B.data(): " << operator_args.ref_B.data() << std::endl; + std::cout << " ref_B.stride: " << operator_args.ref_B.stride(0) << std::endl; + std::cout << " ref_C.data(): " << operator_args.ref_C.data() << std::endl; + std::cout << " ref_C.stride: " << operator_args.ref_C.stride(0) << std::endl; +#endif } }; -///////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// template 
class GemmUniversalOperation : public GemmOperationBase { @@ -742,7 +566,7 @@ class GemmUniversalOperation : public GemmOperationBase { operator_args.ldb = int(configuration->ldb); operator_args.ldc = int(configuration->ldc); operator_args.ldd = int(configuration->ldd); - + return Status::kSuccess; } diff --git a/tools/library/src/library_internal.h b/tools/library/src/library_internal.h index 73847b117f..21190cc825 100644 --- a/tools/library/src/library_internal.h +++ b/tools/library/src/library_internal.h @@ -47,6 +47,7 @@ #include "cutlass/layout/matrix.h" #include "cutlass/library/library.h" +#include "cutlass/library/arch_mappings.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -143,6 +144,14 @@ template <> struct MathOperationMap { static MathOperationID const kId = MathOperationID::kMultiplyAdd; }; +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddFastBF16; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddFastF16; +}; + template <> struct MathOperationMap { static MathOperationID const kId = MathOperationID::kMultiplyAddSaturate; }; @@ -171,6 +180,22 @@ template <> struct LayoutMap { static LayoutTypeID const kId = LayoutTypeID::kRowMajor; }; +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK2; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK2; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK4; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK4; +}; + template <> struct LayoutMap> { static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK16; }; @@ -199,6 +224,9 @@ template <> struct LayoutMap { static LayoutTypeID const kId = LayoutTypeID::kTensorNHWC; }; +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kTensorNDHWC; +}; ///////////////////////////////////////////////////////////////////////////////////////////////// template struct OpcodeClassMap; @@ -229,45 +257,6 @@ template <> struct ComplexTransformMap { ///////////////////////////////////////////////////////////////////////////////////////////////// -template struct ArchMap; - -template <> struct ArchMap { - static int const kMin = 50; - static int const kMax = 1024; -}; - -template <> struct ArchMap { - static int const kMin = 60; - static int const kMax = 1024; -}; - -template <> struct ArchMap { - static int const kMin = 61; - static int const kMax = 1024; -}; - -template <> struct ArchMap { - static int const kMin = 70; - static int const kMax = 1024; -}; - -template <> struct ArchMap { - static int const kMin = 70; - static int const kMax = 75; -}; - -template struct ArchMap { - static int const kMin = 75; - static int const kMax = 1024; -}; - -template struct ArchMap { - static int const kMin = 80; - static int const kMax = 1024; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - template TensorDescription make_TensorDescription(int alignment = 1) { TensorDescription desc; diff --git a/tools/library/src/manifest.cpp b/tools/library/src/manifest.cpp index d4e8a884be..29d2ef156f 100644 --- a/tools/library/src/manifest.cpp +++ b/tools/library/src/manifest.cpp @@ -36,7 +36,6 @@ namespace cutlass { namespace library { 
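// Hedged aside (not part of the patch): library_internal.h maps compile-time CUTLASS types
// onto runtime library enums through specialized trait structs, as in the LayoutMap and
// MathOperationMap hunks above. A minimal standalone version of that pattern, using
// hypothetical stand-in types:
enum class LayoutTypeID { kColumnMajor, kRowMajor, kTensorNDHWC };

struct ColumnMajor {};
struct RowMajor {};
struct TensorNDHWC {};

template <typename Layout> struct LayoutMap;   // primary template intentionally undefined

template <> struct LayoutMap<ColumnMajor> { static constexpr LayoutTypeID kId = LayoutTypeID::kColumnMajor; };
template <> struct LayoutMap<RowMajor>    { static constexpr LayoutTypeID kId = LayoutTypeID::kRowMajor; };
template <> struct LayoutMap<TensorNDHWC> { static constexpr LayoutTypeID kId = LayoutTypeID::kTensorNDHWC; };

// Referencing an unmapped layout fails to compile, which is the behavior the library
// relies on to catch unsupported layout/enum pairings early.
static_assert(LayoutMap<TensorNDHWC>::kId == LayoutTypeID::kTensorNDHWC,
              "trait map resolves at compile time");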
////////////////////////////////////////////////////////////////////////////////////////////////////////// - /// Top-level initialization Status Manifest::initialize() { diff --git a/tools/library/src/reference/gemm.cu b/tools/library/src/reference/gemm.cu new file mode 100644 index 0000000000..8e5361fd20 --- /dev/null +++ b/tools/library/src/reference/gemm.cu @@ -0,0 +1,335 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Instantiates GEMM reference implementations. 
+*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +#include "gemm_reference_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_gemm_reference_operations(Manifest &manifest) { + + make_gemm_real_canonical_layouts< + float, // ElementA + float, // ElementB + float, // ElementC + float, // ElementScalar + float // ElementAccumulator + >(manifest); + + make_gemm_real_canonical_layouts< + tfloat32_t, + tfloat32_t, + float, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + tfloat32_t, + tfloat32_t, + tfloat32_t, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + half_t, + half_t, + half_t, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + half_t, + half_t, + half_t, + half_t, + half_t + >(manifest); + + make_gemm_real_canonical_layouts< + half_t, + half_t, + float, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + bfloat16_t, + bfloat16_t, + bfloat16_t, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + bfloat16_t, + bfloat16_t, + float, + float, + float + >(manifest); + + make_gemm_real_canonical_layouts< + double, + double, + double, + double, + double + >(manifest); + + // + // Integer-valued GEMMs + // + + make_gemm_real_canonical_layouts< + int8_t, + int8_t, + int32_t, + int32_t, + int32_t + >(manifest); + + make_gemm_real_canonical_layouts< + int8_t, + int8_t, + int8_t, + float, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_real_canonical_layouts< + int8_t, + int8_t, + int32_t, + float, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_real_canonical_layouts< + uint8_t, + uint8_t, + int32_t, + int32_t, + int32_t + >(manifest); + + make_gemm_real_canonical_layouts< + uint8_t, + uint8_t, + int8_t, + float, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_real_canonical_layouts< + uint8_t, + uint8_t, + int32_t, + float, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 32, + int8_t, + int8_t, + int32_t, + int32_t, + int32_t + >(manifest); + + make_gemm_interleaved_layouts< + 32, + int8_t, + int8_t, + int32_t, + float, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 32, + int8_t, + int8_t, + int8_t, + float, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 32, + uint8_t, + uint8_t, + int32_t, + int32_t, + int32_t + >(manifest); + + make_gemm_interleaved_layouts< + 32, + uint8_t, + uint8_t, + int32_t, + float, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 32, + uint8_t, + uint8_t, + uint8_t, + float, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 32, + uint8_t, + uint8_t, + int8_t, + float, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 64, + int4b_t, + int4b_t, + int32_t, + int32_t, + int32_t + >(manifest); + + make_gemm_interleaved_layouts< + 64, + int4b_t, + int4b_t, + int32_t, + float, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 64, + int4b_t, + int4b_t, + int4b_t, + float, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 64, + uint4b_t, + 
uint4b_t, + int32_t, + int32_t, + int32_t + >(manifest); + + make_gemm_interleaved_layouts< + 64, + uint4b_t, + uint4b_t, + int32_t, + float, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 64, + uint4b_t, + uint4b_t, + uint4b_t, + float, + int32_t, + NumericConverterClamp + >(manifest); + + make_gemm_interleaved_layouts< + 64, + uint4b_t, + uint4b_t, + int4b_t, + float, + int32_t, + NumericConverterClamp + >(manifest); + + // + // Complex-valued GEMMs + // + + make_gemm_complex_canonical_layouts< + complex, + complex, + complex, + complex, + complex + >(manifest); + + make_gemm_complex_canonical_layouts< + complex, + complex, + complex, + complex, + complex + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/gemm_reference_operation.h b/tools/library/src/reference/gemm_reference_operation.h new file mode 100644 index 0000000000..11a5230bbe --- /dev/null +++ b/tools/library/src/reference/gemm_reference_operation.h @@ -0,0 +1,472 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Defines reference operations for GEMM operation kinds in CUTLASS Library +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +#include "library_internal.h" + +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/device/gemm_complex.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + Provider Provider_, + typename ElementA_, + typename LayoutA_, + cutlass::ComplexTransform TransformA, + typename ElementB_, + typename LayoutB_, + cutlass::ComplexTransform TransformB, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +class GemmReferenceOperation : public Operation { +public: + static Provider const kProvider = Provider_; + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + static cutlass::ComplexTransform const kTransformA = TransformA; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + static cutlass::ComplexTransform const kTransformB = TransformB; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ConvertOp = ConvertOp_; + using InnerProductOp = InnerProductOp_; + +protected: + + /// Storage for the name string + std::string name_; + + /// + GemmDescription description_; + +public: + + /// Constructor + GemmReferenceOperation() { + + // Basic information + description_.provider = kProvider; + description_.kind = OperationKind::kGemm; + description_.gemm_kind = GemmKind::kUniversal; + + // Tensor description + description_.A = make_TensorDescription(); + description_.transform_A = ComplexTransformMap::kId; + description_.B = make_TensorDescription(); + description_.transform_B = ComplexTransformMap::kId; + description_.C = make_TensorDescription(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + // Compute capability for gemm reference + description_.tile_description.minimum_compute_capability = + (kProvider == Provider::kReferenceDevice ? 
50 : 0); + + description_.tile_description.maximum_compute_capability = 1024; + + // Procedural name + std::stringstream ss; + + ss << "gemm" + << "_reference_" << to_string(description_.provider) + << "_" << to_string(description_.A.element) << to_string(description_.A.layout) + << "_" << to_string(description_.B.element) << to_string(description_.B.layout) + << "_" << to_string(description_.C.element) << to_string(description_.C.layout) + << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); + + name_ = ss.str(); + + description_.name = name_.c_str(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + virtual Status can_implement( + void const *configuration, + void const *arguments) const { + + return Status::kSuccess; + } + + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(GemmUniversalConfiguration); + } + + virtual uint64_t get_device_workspace_size( + void const *configuration) const { + + return 0; + } + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + std::memcpy(host_workspace, configuration, get_host_workspace_size(configuration)); + + return Status::kSuccess; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + GemmUniversalConfiguration const &config = *static_cast(host_workspace); + GemmUniversalArguments const &args = *static_cast(arguments); + + ElementCompute alpha; + ElementCompute beta; + + alpha = *static_cast(args.alpha); + beta = *static_cast(args.beta); + + TensorRefA ref_A{static_cast(const_cast(args.A)), LayoutA(int(config.lda))}; + TensorRefB ref_B{static_cast(const_cast(args.B)), LayoutB(int(config.ldb))}; + TensorRefC ref_C{static_cast(const_cast(args.C)), LayoutC(int(config.ldc))}; + TensorRefC ref_D{static_cast(args.D), LayoutC(int(config.ldd))}; + + if (kProvider == Provider::kReferenceHost) { + + cutlass::reference::host::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp + >( + config.problem_size, + alpha, + ref_A, + kTransformA, + ref_B, + kTransformB, + beta, + ref_C, + ref_D, + ElementAccumulator(), + config.batch_count, + args.batch_stride_A, + args.batch_stride_B, + args.batch_stride_C, + args.batch_stride_D + ); + + return Status::kSuccess; + } + else if (kProvider == Provider::kReferenceDevice) { + + cutlass::reference::device::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp + >( + config.problem_size, + alpha, + ref_A, + kTransformA, + ref_B, + kTransformB, + beta, + ref_C, + ref_D, + ElementAccumulator(), + config.batch_count, + args.batch_stride_A, + args.batch_stride_B, + args.batch_stride_C, + args.batch_stride_D + ); + + return Status::kSuccess; + } + + return Status::kErrorNotSupported; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename 
LayoutA_, + cutlass::ComplexTransform TransformA, + typename ElementB_, + typename LayoutB_, + cutlass::ComplexTransform TransformB, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm(Manifest &manifest) { + + manifest.append(new GemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, LayoutA_, TransformA, + ElementB_, LayoutB_, TransformB, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new GemmReferenceOperation< + Provider::kReferenceDevice, + ElementA_, LayoutA_, TransformA, + ElementB_, LayoutB_, TransformB, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); +} + +/// Helper to create NN, NT, TN, and TT GEMM layouts. +template < + typename ElementA_, cutlass::ComplexTransform TransformA, + typename ElementB_, cutlass::ComplexTransform TransformB, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_canonical_layouts(Manifest &manifest) { + + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >(manifest); +} + + +/// Helper to create TN and interleaved layouts GEMM layouts. 
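// Hedged aside (not part of the patch): make_gemm above registers one host-provider and one
// device-provider reference operation, and make_gemm_canonical_layouts simply enumerates the
// NN/NT/TN/TT layout pairings. A minimal standalone illustration of that enumeration
// pattern, with hypothetical stand-in types:
struct DummyColumnMajor {};
struct DummyRowMajor {};

struct DummyManifest {
  int count = 0;
  template <typename Op> void append() { ++count; }  // stands in for manifest.append(new Op)
};

template <typename LayoutA, typename LayoutB>
struct DummyGemmReference {};  // stands in for a GemmReferenceOperation instantiation

template <typename LayoutA, typename LayoutB>
void make_one(DummyManifest &manifest) {
  manifest.append<DummyGemmReference<LayoutA, LayoutB>>();  // host provider
  manifest.append<DummyGemmReference<LayoutA, LayoutB>>();  // device provider
}

inline void make_canonical(DummyManifest &manifest) {
  make_one<DummyColumnMajor, DummyColumnMajor>(manifest);  // NN
  make_one<DummyColumnMajor, DummyRowMajor   >(manifest);  // NT
  make_one<DummyRowMajor,    DummyColumnMajor>(manifest);  // TN
  make_one<DummyRowMajor,    DummyRowMajor   >(manifest);  // TT
}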
+template < + int InterleaveK, + typename ElementA_, + typename ElementB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_interleaved_layouts(Manifest &manifest) { + + make_gemm< + ElementA_, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >(manifest); + +} + +/// Helper to real-valued GEMM with canonical layouts +template < + typename ElementA_, + typename ElementB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_real_canonical_layouts(Manifest &manifest) { + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::ComplexTransform::kNone, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >(manifest); +} + +// Helper to create all complex transformation permutations +template < + typename ElementA_, + typename ElementB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_complex_canonical_layouts(Manifest &manifest) { + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::ComplexTransform::kNone, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kConjugate, + ElementB_, cutlass::ComplexTransform::kConjugate, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::ComplexTransform::kConjugate, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kConjugate, + ElementB_, cutlass::ComplexTransform::kNone, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/reference/initialize_reference_operations.cu b/tools/library/src/reference/initialize_reference_operations.cu new file mode 100644 index 0000000000..016d91a6f2 --- /dev/null +++ b/tools/library/src/reference/initialize_reference_operations.cu @@ -0,0 +1,53 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief + +*/ + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +void initialize_gemm_reference_operations(Manifest &manifest); + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +void initialize_reference_operations(Manifest &manifest) { + initialize_gemm_reference_operations(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/library/src/util.cu b/tools/library/src/util.cu index 427f0a2c52..13fb9dfc0a 100644 --- a/tools/library/src/util.cu +++ b/tools/library/src/util.cu @@ -93,14 +93,13 @@ static struct { } GemmKind_enumerants[] = { {"gemm", "", GemmKind::kGemm}, - {"batched", "", GemmKind::kBatched}, - {"array", "", GemmKind::kArray}, + {"spgemm", "", GemmKind::kSparse}, {"universal", "", GemmKind::kUniversal}, {"planar_complex", "", GemmKind::kPlanarComplex}, {"planar_complex_array", "", GemmKind::kPlanarComplexArray}, }; -/// Converts a ConvKind enumerant to a string +/// Converts a GemmKind enumerant to a string char const *to_string(GemmKind type, bool pretty) { for (auto const & possible : GemmKind_enumerants) { @@ -117,6 +116,8 @@ char const *to_string(GemmKind type, bool pretty) { return pretty ? 
"Invalid" : "invalid"; } +/////////////////////////////////////////////////////////////////////////////////////////////////// + ///////////////////////////////////////////////////////////////////////////////////////////////// static struct { @@ -217,11 +218,13 @@ NumericTypeID_enumerants[] = { {"unknown", "", NumericTypeID::kUnknown}, {"void", "Void", NumericTypeID::kVoid}, {"b1", "B1", NumericTypeID::kB1}, + {"u2", "U2", NumericTypeID::kU2}, {"u4", "U4", NumericTypeID::kU4}, {"u8", "U8", NumericTypeID::kU8}, {"u16", "U16", NumericTypeID::kU16}, {"u32", "U32", NumericTypeID::kU32}, {"u64", "U64", NumericTypeID::kU64}, + {"s2", "S2", NumericTypeID::kS2}, {"s4", "S4", NumericTypeID::kS4}, {"s8", "S8", NumericTypeID::kS8}, {"s16", "S16", NumericTypeID::kS16}, @@ -237,11 +240,13 @@ NumericTypeID_enumerants[] = { {"cf32", "CF32", NumericTypeID::kCF32}, {"ctf32", "CTF32", NumericTypeID::kCTF32}, {"cf64", "CF64", NumericTypeID::kCF64}, + {"cu2", "CU2", NumericTypeID::kCU2}, {"cu4", "CU4", NumericTypeID::kCU4}, {"cu8", "CU8", NumericTypeID::kCU8}, {"cu16", "CU16", NumericTypeID::kCU16}, {"cu32", "CU32", NumericTypeID::kCU32}, {"cu64", "CU64", NumericTypeID::kCU64}, + {"cs2", "CS2", NumericTypeID::kCS2}, {"cs4", "CS4", NumericTypeID::kCS4}, {"cs8", "CS8", NumericTypeID::kCS8}, {"cs16", "CS16", NumericTypeID::kCS16}, @@ -296,11 +301,13 @@ int sizeof_bits(NumericTypeID type) { case NumericTypeID::kCF32: return 64; case NumericTypeID::kCTF32: return 64; case NumericTypeID::kCF64: return 128; + case NumericTypeID::kS2: return 2; case NumericTypeID::kS4: return 4; case NumericTypeID::kS8: return 8; case NumericTypeID::kS16: return 16; case NumericTypeID::kS32: return 32; case NumericTypeID::kS64: return 64; + case NumericTypeID::kU2: return 2; case NumericTypeID::kU4: return 4; case NumericTypeID::kU8: return 8; case NumericTypeID::kU16: return 16; @@ -341,11 +348,13 @@ NumericTypeID get_real_type(NumericTypeID type) { /// Returns true if numeric type is integer bool is_integer_type(NumericTypeID type) { switch (type) { + case NumericTypeID::kS2: return true; case NumericTypeID::kS4: return true; case NumericTypeID::kS8: return true; case NumericTypeID::kS16: return true; case NumericTypeID::kS32: return true; case NumericTypeID::kS64: return true; + case NumericTypeID::kU2: return true; case NumericTypeID::kU4: return true; case NumericTypeID::kU8: return true; case NumericTypeID::kU16: return true; @@ -364,6 +373,7 @@ bool is_signed_type(NumericTypeID type) { case NumericTypeID::kTF32: return true; case NumericTypeID::kF32: return true; case NumericTypeID::kF64: return true; + case NumericTypeID::kS2: return true; case NumericTypeID::kS4: return true; case NumericTypeID::kS8: return true; case NumericTypeID::kS16: return true; @@ -415,6 +425,12 @@ layout_aliases[] = { {LayoutTypeID::kColumnMajor, "column"}, {LayoutTypeID::kColumnMajor, "col"}, {LayoutTypeID::kColumnMajor, "n"}, + + {LayoutTypeID::kColumnMajorInterleavedK2, "nk2"}, + {LayoutTypeID::kRowMajorInterleavedK2, "tk2"}, + + {LayoutTypeID::kColumnMajorInterleavedK4, "nk4"}, + {LayoutTypeID::kRowMajorInterleavedK4, "tk4"}, {LayoutTypeID::kColumnMajorInterleavedK16, "nk16"}, {LayoutTypeID::kRowMajorInterleavedK16, "tk16"}, @@ -426,7 +442,10 @@ layout_aliases[] = { {LayoutTypeID::kRowMajorInterleavedK64, "tk64"}, {LayoutTypeID::kTensorNCHW, "nchw"}, + {LayoutTypeID::kTensorNCDHW, "ncdhw"}, {LayoutTypeID::kTensorNHWC, "nhwc"}, + {LayoutTypeID::kTensorNDHWC, "ndhwc"}, + {LayoutTypeID::kUnknown, "*"}, {LayoutTypeID::kInvalid, nullptr} }; @@ -457,6 
+476,8 @@ int get_layout_stride_rank(LayoutTypeID layout_id) { switch (layout_id) { case LayoutTypeID::kColumnMajor: return cutlass::layout::ColumnMajor::kStrideRank; case LayoutTypeID::kRowMajor: return cutlass::layout::RowMajor::kStrideRank; + case LayoutTypeID::kColumnMajorInterleavedK2: + case LayoutTypeID::kRowMajorInterleavedK2: case LayoutTypeID::kColumnMajorInterleavedK4: case LayoutTypeID::kRowMajorInterleavedK4: case LayoutTypeID::kColumnMajorInterleavedK16: @@ -464,10 +485,10 @@ int get_layout_stride_rank(LayoutTypeID layout_id) { case LayoutTypeID::kColumnMajorInterleavedK32: case LayoutTypeID::kRowMajorInterleavedK32: case LayoutTypeID::kColumnMajorInterleavedK64: - case LayoutTypeID::kRowMajorInterleavedK64: - return 1; + case LayoutTypeID::kRowMajorInterleavedK64: return 1; case LayoutTypeID::kTensorNCHW: case LayoutTypeID::kTensorNHWC: return 3; + case LayoutTypeID::kTensorNDHWC: return 4; default : throw std::runtime_error("Unsupported LayoutTypeID in LayoutType::get_stride_rank"); } } @@ -969,12 +990,12 @@ bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t sr break; case NumericTypeID::kCF32: { - *reinterpret_cast*>(bytes.data()) = std::complex(float(src), float(0)); + *reinterpret_cast*>(bytes.data()) = cutlass::complex(float(src), float(0)); } break; case NumericTypeID::kCF64: { - *reinterpret_cast*>(bytes.data()) = std::complex(double(src), double(0)); + *reinterpret_cast*>(bytes.data()) = cutlass::complex(double(src), double(0)); } break; default: @@ -1177,17 +1198,17 @@ bool cast_from_double(std::vector &bytes, NumericTypeID type, double sr break; case NumericTypeID::kCF32: { - *reinterpret_cast*>(bytes.data()) = std::complex(float(src), float(0)); + *reinterpret_cast*>(bytes.data()) = cutlass::complex(float(src), float()); } break; case NumericTypeID::kCTF32: { - *reinterpret_cast*>(bytes.data()) = std::complex(tfloat32_t(src), tfloat32_t(0)); + *reinterpret_cast*>(bytes.data()) = cutlass::complex(tfloat32_t(src), tfloat32_t()); } break; case NumericTypeID::kCF64: { - *reinterpret_cast*>(bytes.data()) = std::complex(src, double(0)); + *reinterpret_cast*>(bytes.data()) = cutlass::complex(src, double()); } break; default: diff --git a/tools/profiler/CMakeLists.txt b/tools/profiler/CMakeLists.txt index a47c831415..52baacb1aa 100644 --- a/tools/profiler/CMakeLists.txt +++ b/tools/profiler/CMakeLists.txt @@ -37,6 +37,7 @@ set(CUTLASS_TOOLS_PROFILER_SOURCES src/problem_space.cpp src/operation_profiler.cu src/gemm_operation_profiler.cu + src/sparse_gemm_operation_profiler.cu ) # diff --git a/tools/profiler/src/cublas_helpers.cpp b/tools/profiler/src/cublas_helpers.cpp index 05262a22de..3369d9615a 100644 --- a/tools/profiler/src/cublas_helpers.cpp +++ b/tools/profiler/src/cublas_helpers.cpp @@ -51,6 +51,18 @@ Status get_cutlass_status(cublasStatus_t cublas) { return Status::kErrorInternal; } +/// Converts a cuBLAS status to cutlass::profiler::Disposition +Disposition get_cutlass_disposition(cublasStatus_t cublas_status) { + + if (cublas_status == CUBLAS_STATUS_INVALID_VALUE) { + return Disposition::kInvalidProblem; + } + else if (cublas_status == CUBLAS_STATUS_NOT_SUPPORTED) { + return Disposition::kNotSupported; + } + return Disposition::kFailed; +} + /// Maps a CUTLASS tensor layout to a cuBLAS transpose operation bool get_cublas_transpose_operation( cublasOperation_t &operation, @@ -156,7 +168,6 @@ bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID ele /// Gets the cublas algorithm given threadblock tile dimensions and math
opcode class cublasGemmAlgo_t get_cublas_gemm_algo(int cta_m, int cta_n, int cta_k, library::OpcodeClassID opcode_class) { - // TODO return (opcode_class == library::OpcodeClassID::kSimt ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP); } @@ -252,37 +263,69 @@ cublasGemmExDispatcher::cublasGemmExDispatcher( /// Executes GEMM using these arguments cublasStatus_t cublasGemmExDispatcher::operator()(cublasHandle_t handle) { - return cublasGemmEx( - handle, - trans_A, - trans_B, - configuration.problem_size.m(), - configuration.problem_size.n(), - configuration.problem_size.k(), - arguments.alpha, - arguments.A, - data_type_A, - int(configuration.lda), - arguments.B, - data_type_B, - int(configuration.ldb), - arguments.beta, - arguments.D, - data_type_C, - int(configuration.ldc), -#if (__CUDA_VER_MAJOR__ >= 11) - compute_type, -#else - compute_data_type, -#endif - algo - ); + if (configuration.mode == library::GemmUniversalMode::kBatched) { + return cublasGemmStridedBatchedEx( + handle, + trans_A, + trans_B, + configuration.problem_size.m(), + configuration.problem_size.n(), + configuration.problem_size.k(), + arguments.alpha, + arguments.A, + data_type_A, + int(configuration.lda), + arguments.batch_stride_A, + arguments.B, + data_type_B, + int(configuration.ldb), + arguments.batch_stride_B, + arguments.beta, + arguments.D, + data_type_C, + int(configuration.ldc), + arguments.batch_stride_C, + configuration.batch_count, + #if (__CUDA_VER_MAJOR__ >= 11) + compute_type, + #else + compute_data_type, + #endif + algo + ); + } + else { + return cublasGemmEx( + handle, + trans_A, + trans_B, + configuration.problem_size.m(), + configuration.problem_size.n(), + configuration.problem_size.k(), + arguments.alpha, + arguments.A, + data_type_A, + int(configuration.lda), + arguments.B, + data_type_B, + int(configuration.ldb), + arguments.beta, + arguments.D, + data_type_C, + int(configuration.ldc), + #if (__CUDA_VER_MAJOR__ >= 11) + compute_type, + #else + compute_data_type, + #endif + algo + ); + } } - -///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace detail +///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace profiler } // namespace cutlass diff --git a/tools/profiler/src/cublas_helpers.h b/tools/profiler/src/cublas_helpers.h index 9c8078466a..c2bf13b5f7 100644 --- a/tools/profiler/src/cublas_helpers.h +++ b/tools/profiler/src/cublas_helpers.h @@ -47,6 +47,9 @@ namespace profiler { /// Converts a cuBLAS status to cutlass::Status Status get_cutlass_status(cublasStatus_t cublas); +/// Converts a cuBLAS status to cutlass::profiler::Disposition +Disposition get_cutlass_disposition(cublasStatus_t cublas_status); + /// Maps a CUTLASS tensor layout to a cuBLAS transpose operation bool get_cublas_transpose_operation( cublasOperation_t &operation, diff --git a/tools/profiler/src/cutlass_profiler.cu b/tools/profiler/src/cutlass_profiler.cu index 90f4a95970..9934ff4cd6 100644 --- a/tools/profiler/src/cutlass_profiler.cu +++ b/tools/profiler/src/cutlass_profiler.cu @@ -32,6 +32,8 @@ // Profiler includes #include "cutlass_profiler.h" #include "gemm_operation_profiler.h" +#include "sparse_gemm_operation_profiler.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -46,6 +48,8 @@ CutlassProfiler::CutlassProfiler( operation_profilers_.emplace_back(new GemmOperationProfiler(options)); + 
operation_profilers_.emplace_back(new SparseGemmOperationProfiler(options)); + } CutlassProfiler::~CutlassProfiler() { diff --git a/tools/profiler/src/device_allocation.cu b/tools/profiler/src/device_allocation.cu index 4045abfeec..777fb4d0aa 100644 --- a/tools/profiler/src/device_allocation.cu +++ b/tools/profiler/src/device_allocation.cu @@ -94,6 +94,12 @@ std::vector DeviceAllocation::get_packed_layout( case library::LayoutTypeID::kRowMajor: stride = get_packed_layout_stride(extent); break; + case library::LayoutTypeID::kColumnMajorInterleavedK2: + stride = get_packed_layout_stride>(extent); + break; + case library::LayoutTypeID::kRowMajorInterleavedK2: + stride = get_packed_layout_stride>(extent); + break; case library::LayoutTypeID::kColumnMajorInterleavedK4: stride = get_packed_layout_stride>(extent); break; @@ -124,7 +130,9 @@ std::vector DeviceAllocation::get_packed_layout( case library::LayoutTypeID::kTensorNHWC: stride = get_packed_layout_stride(extent); break; - + case library::LayoutTypeID::kTensorNDHWC: + stride = get_packed_layout_stride(extent); + break; default: break; } @@ -200,6 +208,12 @@ size_t DeviceAllocation::construct_layout( case library::LayoutTypeID::kRowMajor: return construct_layout_(bytes, layout_id, extent, stride); + case library::LayoutTypeID::kColumnMajorInterleavedK2: + return construct_layout_>(bytes, layout_id, extent, stride); + + case library::LayoutTypeID::kRowMajorInterleavedK2: + return construct_layout_>(bytes, layout_id, extent, stride); + case library::LayoutTypeID::kColumnMajorInterleavedK4: return construct_layout_>(bytes, layout_id, extent, stride); @@ -230,6 +244,9 @@ size_t DeviceAllocation::construct_layout( case library::LayoutTypeID::kTensorNHWC: return construct_layout_(bytes, layout_id, extent, stride); + case library::LayoutTypeID::kTensorNDHWC: + return construct_layout_(bytes, layout_id, extent, stride); + default: break; } @@ -240,9 +257,11 @@ size_t DeviceAllocation::construct_layout( DeviceAllocation::DeviceAllocation(): type_(library::NumericTypeID::kInvalid), + batch_stride_(0), capacity_(0), pointer_(nullptr), - layout_(library::LayoutTypeID::kUnknown) { + layout_(library::LayoutTypeID::kUnknown), + batch_count_(1) { } @@ -250,7 +269,8 @@ DeviceAllocation::DeviceAllocation( library::NumericTypeID type, size_t capacity ): - type_(type), capacity_(capacity), pointer_(nullptr), layout_(library::LayoutTypeID::kUnknown) { + type_(type), batch_stride_(capacity), capacity_(capacity), pointer_(nullptr), + layout_(library::LayoutTypeID::kUnknown), batch_count_(1) { cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type, capacity)); @@ -266,11 +286,12 @@ DeviceAllocation::DeviceAllocation( library::NumericTypeID type, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector const &stride + std::vector const &stride, + int batch_count ): - type_(type), capacity_(size_t(0)), pointer_(nullptr) { + type_(type), batch_stride_(size_t(0)), capacity_(size_t(0)), pointer_(nullptr), batch_count_(1) { - reset(type, layout_id, extent, stride); + reset(type, layout_id, extent, stride, batch_count); } DeviceAllocation::~DeviceAllocation() { @@ -285,12 +306,14 @@ DeviceAllocation &DeviceAllocation::reset() { } type_ = library::NumericTypeID::kInvalid; + batch_stride_ = 0; capacity_ = 0; pointer_ = nullptr; layout_ = library::LayoutTypeID::kUnknown; stride_.clear(); extent_.clear(); tensor_ref_buffer_.clear(); + batch_count_ = 1; return *this; } @@ -299,16 +322,19 @@ DeviceAllocation 
&DeviceAllocation::reset(library::NumericTypeID type, size_t ca reset(); - cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type, capacity)); + type_ = type; + batch_stride_ = capacity; + capacity_ = capacity; + + cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type_, capacity_)); if (result != cudaSuccess) { throw std::bad_alloc(); } - type_ = type; - capacity_ = capacity; layout_ = library::LayoutTypeID::kUnknown; stride_.clear(); extent_.clear(); + batch_count_ = 1; tensor_ref_buffer_.resize(sizeof(pointer_), 0); std::memcpy(tensor_ref_buffer_.data(), &pointer_, sizeof(pointer_)); @@ -321,7 +347,8 @@ DeviceAllocation &DeviceAllocation::reset( library::NumericTypeID type, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector const &stride) { + std::vector const &stride, + int batch_count) { reset(); @@ -332,13 +359,16 @@ DeviceAllocation &DeviceAllocation::reset( layout_ = layout_id; stride_ = stride; extent_ = extent; + batch_count_ = batch_count; - capacity_ = construct_layout( + batch_stride_ = construct_layout( tensor_ref_buffer_.data() + sizeof(pointer_), layout_id, extent, stride_); + capacity_ = batch_stride_ * batch_count_; + cudaError_t result = cudaMalloc((void **)&pointer_, bytes(type, capacity_)); if (result != cudaSuccess) { throw std::bad_alloc(); @@ -361,6 +391,10 @@ void *DeviceAllocation::data() const { return pointer_; } +void *DeviceAllocation::batch_data(int batch_idx) const { + return static_cast(data()) + batch_stride_bytes() * batch_idx; +} + library::LayoutTypeID DeviceAllocation::layout() const { return layout_; } @@ -374,6 +408,21 @@ std::vector const & DeviceAllocation::extent() const { return extent_; } +/// Gets the number of adjacent tensors in memory +int DeviceAllocation::batch_count() const { + return batch_count_; +} + +/// Gets the stride (in units of elements) between items +int64_t DeviceAllocation::batch_stride() const { + return batch_stride_; +} + +/// Gets the stride (in units of bytes) between items +int64_t DeviceAllocation::batch_stride_bytes() const { + return bytes(type_, batch_stride_); +} + size_t DeviceAllocation::capacity() const { return capacity_; } @@ -423,6 +472,22 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) { dist ); break; + case library::NumericTypeID::kBF16: + cutlass::reference::device::BlockFillRandom( + reinterpret_cast(pointer_), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kTF32: + cutlass::reference::device::BlockFillRandom( + reinterpret_cast(pointer_), + capacity_, + seed, + dist + ); + break; case library::NumericTypeID::kF32: cutlass::reference::device::BlockFillRandom( reinterpret_cast(pointer_), @@ -431,6 +496,22 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) { dist ); break; + case library::NumericTypeID::kCBF16: + cutlass::reference::device::BlockFillRandom>( + reinterpret_cast *>(pointer_), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kCTF32: + cutlass::reference::device::BlockFillRandom>( + reinterpret_cast *>(pointer_), + capacity_, + seed, + dist + ); + break; case library::NumericTypeID::kCF32: cutlass::reference::device::BlockFillRandom>( reinterpret_cast *>(pointer_), @@ -455,6 +536,22 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) { dist ); break; + case library::NumericTypeID::kS2: + cutlass::reference::device::BlockFillRandom( + reinterpret_cast(pointer_), + capacity_, + seed, + dist + ); + break; + case 
library::NumericTypeID::kS4: + cutlass::reference::device::BlockFillRandom( + reinterpret_cast(pointer_), + capacity_, + seed, + dist + ); + break; case library::NumericTypeID::kS8: cutlass::reference::device::BlockFillRandom( reinterpret_cast(pointer_), @@ -487,6 +584,30 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) { dist ); break; + case library::NumericTypeID::kB1: + cutlass::reference::device::BlockFillRandom( + reinterpret_cast(pointer_), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kU2: + cutlass::reference::device::BlockFillRandom( + reinterpret_cast(pointer_), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kU4: + cutlass::reference::device::BlockFillRandom( + reinterpret_cast(pointer_), + capacity_, + seed, + dist + ); + break; case library::NumericTypeID::kU8: cutlass::reference::device::BlockFillRandom( reinterpret_cast(pointer_), @@ -523,7 +644,6 @@ void DeviceAllocation::initialize_random_device(int seed, Distribution dist) { } } - void DeviceAllocation::initialize_random_host(int seed, Distribution dist) { if (!good()) { throw std::runtime_error("Attempting to initialize invalid allocation."); @@ -540,6 +660,22 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) { dist ); break; + case library::NumericTypeID::kBF16: + cutlass::reference::host::BlockFillRandom( + reinterpret_cast(host_data.data()), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kTF32: + cutlass::reference::host::BlockFillRandom( + reinterpret_cast(host_data.data()), + capacity_, + seed, + dist + ); + break; case library::NumericTypeID::kF32: cutlass::reference::host::BlockFillRandom( reinterpret_cast(host_data.data()), @@ -556,6 +692,22 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) { dist ); break; + case library::NumericTypeID::kCBF16: + cutlass::reference::host::BlockFillRandom>( + reinterpret_cast *>(host_data.data()), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kCTF32: + cutlass::reference::host::BlockFillRandom>( + reinterpret_cast *>(host_data.data()), + capacity_, + seed, + dist + ); + break; case library::NumericTypeID::kCF32: cutlass::reference::host::BlockFillRandom>( reinterpret_cast *>(host_data.data()), @@ -580,6 +732,22 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) { dist ); break; + case library::NumericTypeID::kS2: + cutlass::reference::host::BlockFillRandom( + reinterpret_cast(host_data.data()), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kS4: + cutlass::reference::host::BlockFillRandom( + reinterpret_cast(host_data.data()), + capacity_, + seed, + dist + ); + break; case library::NumericTypeID::kS8: cutlass::reference::host::BlockFillRandom( reinterpret_cast(host_data.data()), @@ -612,6 +780,30 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) { dist ); break; + case library::NumericTypeID::kB1: + cutlass::reference::host::BlockFillRandom( + reinterpret_cast(host_data.data()), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kU2: + cutlass::reference::host::BlockFillRandom( + reinterpret_cast(host_data.data()), + capacity_, + seed, + dist + ); + break; + case library::NumericTypeID::kU4: + cutlass::reference::host::BlockFillRandom( + reinterpret_cast(host_data.data()), + capacity_, + seed, + dist + ); + break; case library::NumericTypeID::kU8: cutlass::reference::host::BlockFillRandom( 
reinterpret_cast(host_data.data()), @@ -650,6 +842,67 @@ void DeviceAllocation::initialize_random_host(int seed, Distribution dist) { copy_from_host(host_data.data()); } +void DeviceAllocation::initialize_random_sparsemeta_device(int seed, int MetaSizeInBits) { + if (!good()) { + throw std::runtime_error("Attempting to initialize invalid allocation."); + } + + // Instantiate calls to CURAND here. This file takes a long time to compile for + // this reason. + + switch (type_) { + case library::NumericTypeID::kU16: + cutlass::reference::device::BlockFillRandomSparseMeta( + reinterpret_cast(pointer_), + capacity_, + seed, + MetaSizeInBits + ); + break; + case library::NumericTypeID::kU32: + cutlass::reference::device::BlockFillRandomSparseMeta( + reinterpret_cast(pointer_), + capacity_, + seed, + MetaSizeInBits + ); + break; + default: + break; + } +} + +void DeviceAllocation::initialize_random_sparsemeta_host(int seed, int MetaSizeInBits) { + if (!good()) { + throw std::runtime_error("Attempting to initialize invalid allocation."); + } + + std::vector host_data(bytes()); + + switch (type_) { + case library::NumericTypeID::kS16: + cutlass::reference::host::BlockFillRandomSparseMeta( + reinterpret_cast(host_data.data()), + capacity_, + seed, + MetaSizeInBits + ); + break; + case library::NumericTypeID::kS32: + cutlass::reference::host::BlockFillRandomSparseMeta( + reinterpret_cast(host_data.data()), + capacity_, + seed, + MetaSizeInBits + ); + break; + default: + break; + } + + copy_from_host(host_data.data()); +} + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Returns true if two blocks have exactly the same value @@ -666,6 +919,18 @@ bool DeviceAllocation::block_compare_equal( reinterpret_cast(ptr_B), capacity); + case library::NumericTypeID::kBF16: + return reference::device::BlockCompareEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity); + + case library::NumericTypeID::kTF32: + return reference::device::BlockCompareEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity); + case library::NumericTypeID::kF32: return reference::device::BlockCompareEqual( reinterpret_cast(ptr_A), @@ -684,6 +949,18 @@ bool DeviceAllocation::block_compare_equal( reinterpret_cast const *>(ptr_B), capacity); + case library::NumericTypeID::kCBF16: + return reference::device::BlockCompareEqual>( + reinterpret_cast const *>(ptr_A), + reinterpret_cast const *>(ptr_B), + capacity); + + case library::NumericTypeID::kCTF32: + return reference::device::BlockCompareEqual>( + reinterpret_cast const *>(ptr_A), + reinterpret_cast const *>(ptr_B), + capacity); + case library::NumericTypeID::kF64: return reference::device::BlockCompareEqual( reinterpret_cast(ptr_A), @@ -695,6 +972,18 @@ bool DeviceAllocation::block_compare_equal( reinterpret_cast const *>(ptr_A), reinterpret_cast const *>(ptr_B), capacity); + + case library::NumericTypeID::kS2: + return reference::device::BlockCompareEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity); + + case library::NumericTypeID::kS4: + return reference::device::BlockCompareEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity); case library::NumericTypeID::kS8: return reference::device::BlockCompareEqual( @@ -719,6 +1008,24 @@ bool DeviceAllocation::block_compare_equal( reinterpret_cast(ptr_A), reinterpret_cast(ptr_B), capacity); + + case library::NumericTypeID::kB1: + return reference::device::BlockCompareEqual( + reinterpret_cast(ptr_A), + 
reinterpret_cast(ptr_B), + capacity); + + case library::NumericTypeID::kU2: + return reference::device::BlockCompareEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity); + + case library::NumericTypeID::kU4: + return reference::device::BlockCompareEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity); case library::NumericTypeID::kU8: return reference::device::BlockCompareEqual( @@ -767,6 +1074,22 @@ bool DeviceAllocation::block_compare_relatively_equal( static_cast(epsilon), static_cast(nonzero_floor)); + case library::NumericTypeID::kBF16: + return reference::device::BlockCompareRelativelyEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity, + static_cast(epsilon), + static_cast(nonzero_floor)); + + case library::NumericTypeID::kTF32: + return reference::device::BlockCompareRelativelyEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity, + static_cast(epsilon), + static_cast(nonzero_floor)); + case library::NumericTypeID::kF32: return reference::device::BlockCompareRelativelyEqual( reinterpret_cast(ptr_A), @@ -782,6 +1105,22 @@ bool DeviceAllocation::block_compare_relatively_equal( capacity, static_cast(epsilon), static_cast(nonzero_floor)); + + case library::NumericTypeID::kS2: + return reference::device::BlockCompareRelativelyEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity, + static_cast(epsilon), + static_cast(nonzero_floor)); + + case library::NumericTypeID::kS4: + return reference::device::BlockCompareRelativelyEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity, + static_cast(epsilon), + static_cast(nonzero_floor)); case library::NumericTypeID::kS8: return reference::device::BlockCompareRelativelyEqual( @@ -814,6 +1153,30 @@ bool DeviceAllocation::block_compare_relatively_equal( capacity, static_cast(epsilon), static_cast(nonzero_floor)); + + case library::NumericTypeID::kB1: + return reference::device::BlockCompareRelativelyEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity, + static_cast(epsilon), + static_cast(nonzero_floor)); + + case library::NumericTypeID::kU2: + return reference::device::BlockCompareRelativelyEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity, + static_cast(epsilon), + static_cast(nonzero_floor)); + + case library::NumericTypeID::kU4: + return reference::device::BlockCompareRelativelyEqual( + reinterpret_cast(ptr_A), + reinterpret_cast(ptr_B), + capacity, + static_cast(epsilon), + static_cast(nonzero_floor)); case library::NumericTypeID::kU8: return reference::device::BlockCompareRelativelyEqual( @@ -852,6 +1215,12 @@ bool DeviceAllocation::block_compare_relatively_equal( // As a simplification, we can require bitwise equality. This avoids false positives. // (i.e. "pass" really means passing. "Fail" may not actually mean failure given appropriate epsilon.) 
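// For example, under bitwise equality two complex values that differ only in the last mantissa bit of either component are reported as unequal, even though an epsilon-based comparison would typically accept them.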
// + case library::NumericTypeID::kCF16: + return reference::device::BlockCompareEqual >( + reinterpret_cast const *>(ptr_A), + reinterpret_cast const *>(ptr_B), + capacity); + case library::NumericTypeID::kCF32: return reference::device::BlockCompareEqual >( reinterpret_cast const *>(ptr_A), @@ -865,7 +1234,9 @@ bool DeviceAllocation::block_compare_relatively_equal( capacity); default: - throw std::runtime_error("Unsupported numeric type"); + { + throw std::runtime_error(std::string("Unsupported numeric type: ") + to_string(numeric_type)); + } } } @@ -928,13 +1299,17 @@ static void write_tensor_csv_static_tensor_view( Layout layout(stride); HostTensor host_tensor(extent, layout, false); - if (host_tensor.capacity() != allocation.capacity()) { + if (host_tensor.capacity() != allocation.batch_stride()) { throw std::runtime_error("Unexpected capacity to equal."); } - host_tensor.copy_in_device_to_host(static_cast(allocation.data()), host_tensor.capacity()); + host_tensor.copy_in_device_to_host( + static_cast(allocation.data()), + allocation.batch_stride()); TensorViewWrite(out, host_tensor.host_view()); + + out << "\n\n"; } ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -951,9 +1326,42 @@ static void write_tensor_csv_static_type( case library::LayoutTypeID::kColumnMajor: write_tensor_csv_static_tensor_view(out, allocation); break; + case library::LayoutTypeID::kRowMajorInterleavedK2: + write_tensor_csv_static_tensor_view>(out, allocation); + break; + case library::LayoutTypeID::kColumnMajorInterleavedK2: + write_tensor_csv_static_tensor_view>(out, allocation); + break; + case library::LayoutTypeID::kRowMajorInterleavedK4: + write_tensor_csv_static_tensor_view>(out, allocation); + break; + case library::LayoutTypeID::kColumnMajorInterleavedK4: + write_tensor_csv_static_tensor_view>(out, allocation); + break; + case library::LayoutTypeID::kRowMajorInterleavedK16: + write_tensor_csv_static_tensor_view>(out, allocation); + break; + case library::LayoutTypeID::kColumnMajorInterleavedK16: + write_tensor_csv_static_tensor_view>(out, allocation); + break; + case library::LayoutTypeID::kRowMajorInterleavedK32: + write_tensor_csv_static_tensor_view>(out, allocation); + break; + case library::LayoutTypeID::kColumnMajorInterleavedK32: + write_tensor_csv_static_tensor_view>(out, allocation); + break; + case library::LayoutTypeID::kRowMajorInterleavedK64: + write_tensor_csv_static_tensor_view>(out, allocation); + break; + case library::LayoutTypeID::kColumnMajorInterleavedK64: + write_tensor_csv_static_tensor_view>(out, allocation); + break; case library::LayoutTypeID::kTensorNHWC: write_tensor_csv_static_tensor_view(out, allocation); break; + case library::LayoutTypeID::kTensorNDHWC: + write_tensor_csv_static_tensor_view(out, allocation); + break; default: throw std::runtime_error("Unhandled layout"); } @@ -970,6 +1378,14 @@ void DeviceAllocation::write_tensor_csv( write_tensor_csv_static_type(out, *this); break; + case library::NumericTypeID::kBF16: + write_tensor_csv_static_type(out, *this); + break; + + case library::NumericTypeID::kTF32: + write_tensor_csv_static_type(out, *this); + break; + case library::NumericTypeID::kF32: write_tensor_csv_static_type(out, *this); break; @@ -977,6 +1393,14 @@ void DeviceAllocation::write_tensor_csv( case library::NumericTypeID::kF64: write_tensor_csv_static_type(out, *this); break; + + case library::NumericTypeID::kS2: + write_tensor_csv_static_type(out, *this); + break; + + case library::NumericTypeID::kS4: + 
write_tensor_csv_static_type(out, *this); + break; case library::NumericTypeID::kS8: write_tensor_csv_static_type(out, *this); @@ -993,6 +1417,18 @@ void DeviceAllocation::write_tensor_csv( case library::NumericTypeID::kS64: write_tensor_csv_static_type(out, *this); break; + + case library::NumericTypeID::kB1: + write_tensor_csv_static_type(out, *this); + break; + + case library::NumericTypeID::kU2: + write_tensor_csv_static_type(out, *this); + break; + + case library::NumericTypeID::kU4: + write_tensor_csv_static_type(out, *this); + break; case library::NumericTypeID::kU8: write_tensor_csv_static_type(out, *this); @@ -1010,6 +1446,10 @@ void DeviceAllocation::write_tensor_csv( write_tensor_csv_static_type(out, *this); break; + case library::NumericTypeID::kCF16: + write_tensor_csv_static_type >(out, *this); + break; + case library::NumericTypeID::kCF32: write_tensor_csv_static_type >(out, *this); break; diff --git a/tools/profiler/src/device_allocation.h b/tools/profiler/src/device_allocation.h index f57cda1431..b7bb5ec729 100644 --- a/tools/profiler/src/device_allocation.h +++ b/tools/profiler/src/device_allocation.h @@ -51,6 +51,9 @@ class DeviceAllocation { /// Data type of contained elements library::NumericTypeID type_; + /// Stride (in units of elements) between adjacent tensors in a batched allocation + size_t batch_stride_; + /// Capacity in elements of device allocation size_t capacity_; @@ -66,6 +69,9 @@ class DeviceAllocation { /// Extent vector std::vector extent_; + /// Support allocating a 'batch' of non-overlapping tensors in contiguous memory + int batch_count_; + /// Buffer holding TensorRef instance to recently allocated memory std::vector tensor_ref_buffer_; @@ -118,7 +124,8 @@ class DeviceAllocation { library::NumericTypeID type, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector const &stride = std::vector()); + std::vector const &stride = std::vector(), + int batch_count = 1); ~DeviceAllocation(); @@ -132,7 +139,8 @@ class DeviceAllocation { library::NumericTypeID type, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector const &stride = std::vector()); + std::vector const &stride = std::vector(), + int batch_count = 1); /// Returns a buffer owning the tensor reference std::vector &tensor_ref() { @@ -144,9 +152,12 @@ class DeviceAllocation { /// Data type of contained elements library::NumericTypeID type() const; - /// Pointer to device memory + /// Pointer to start of device memory allocation void *data() const; + /// Pointer to the first element of a batch + void *batch_data(int batch_idx) const; + /// Gets the layout type library::LayoutTypeID layout() const; @@ -156,6 +167,15 @@ class DeviceAllocation { /// Gets the extent vector std::vector const & extent() const; + /// Gets the number of adjacent tensors in memory + int batch_count() const; + + /// Gets the stride (in units of elements) between items + int64_t batch_stride() const; + + /// Gets the stride (in units of bytes) between items + int64_t batch_stride_bytes() const; + /// Capacity of allocation in number of elements size_t capacity() const; @@ -165,9 +185,15 @@ class DeviceAllocation { /// Initializes a device allocation to a random distribution using cuRAND void initialize_random_device(int seed, Distribution dist); - /// Initializes a device allocation to a random distribution using cuRAND + /// Initializes a host allocation to a random distribution void initialize_random_host(int seed, Distribution dist); + /// Initializes a device allocation with random sparse metadata using cuRAND + 
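/// The metadata encodes the positions of the nonzero elements of the structured-sparse operand; the + /// reference routines fill it with valid random patterns, using 2-bit indices for most operand types + /// and 4-bit indices for 32-bit (TF32) operands (see DeviceContext::allocate_sparsemeta_tensor). + 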
void initialize_random_sparsemeta_device(int seed, int MetaSizeInBits); + + /// Initializes a host allocation with random sparse metadata + void initialize_random_sparsemeta_host(int seed, int MetaSizeInBits); + /// Copies from an equivalent-sized tensor in device memory void copy_from_device(void const *ptr); diff --git a/tools/profiler/src/device_context.cu b/tools/profiler/src/device_context.cu index f9cfe9ab58..a8bd4fa218 100644 --- a/tools/profiler/src/device_context.cu +++ b/tools/profiler/src/device_context.cu @@ -52,9 +52,10 @@ DeviceAllocation *DeviceContext::allocate_tensor( library::NumericTypeID type, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector const &stride) { + std::vector const &stride, + int batch_count) { - device_memory_.emplace_back(type, layout_id, extent, stride); + device_memory_.emplace_back(type, layout_id, extent, stride, batch_count); DeviceAllocation *allocation = &device_memory_.back(); allocations_[name] = allocation; @@ -68,10 +69,11 @@ DeviceAllocation *DeviceContext::allocate_tensor( library::NumericTypeID type, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector const &stride) { + std::vector const &stride, + int batch_count) { DeviceAllocation *allocation = - allocate_tensor(name, type, layout_id, extent, stride); + allocate_tensor(name, type, layout_id, extent, stride, batch_count); if (options.initialization.enabled) { Distribution data_distribution = options.initialization.data_distribution; @@ -81,10 +83,22 @@ DeviceAllocation *DeviceContext::allocate_tensor( // change data distribution based on bit width switch(type) { case library::NumericTypeID::kB1: + data_distribution.set_uniform(0, 1, 0); + break; + case library::NumericTypeID::kS2: + data_distribution.set_uniform(-1, 1, 0); + break; + case library::NumericTypeID::kS4: + data_distribution.set_uniform(-2, 2, 0); + break; + case library::NumericTypeID::kU2: data_distribution.set_uniform(0, 2, 0); - break; + break; + case library::NumericTypeID::kU4: + data_distribution.set_uniform(0, 2, 0); + break; case library::NumericTypeID::kS8: - data_distribution.set_uniform(-2, 2, 0); + data_distribution.set_uniform(-3, 3, 0); break; case library::NumericTypeID::kU8: data_distribution.set_uniform(0, 4, 0); @@ -96,18 +110,50 @@ DeviceAllocation *DeviceContext::allocate_tensor( if (options.initialization.provider == library::Provider::kReferenceDevice) { allocation->initialize_random_device( options.initialization.seed, - data_distribution); + data_distribution); } else if (options.initialization.provider == library::Provider::kReferenceHost) { allocation->initialize_random_host( options.initialization.seed, - data_distribution); + data_distribution); } } return allocation; } +/// Allocates memory for sparse meta data +DeviceAllocation *DeviceContext::allocate_sparsemeta_tensor( + Options const &options, + std::string const &name, + library::NumericTypeID type, + library::LayoutTypeID layout_id, + library::NumericTypeID type_a, + std::vector const &extent, + std::vector const &stride, + int batch_count) { + + DeviceAllocation *allocation = + allocate_tensor(name, type, layout_id, extent, stride, batch_count); + + if (options.initialization.enabled) { + // TF32 has 4-bit meta data. The rest have 2-bit. + int MetaSizeInBits = (cutlass::library::sizeof_bits(type_a) == 32) ? 
4 : 2; + + if (options.initialization.provider == library::Provider::kReferenceDevice) { + allocation->initialize_random_sparsemeta_device( + options.initialization.seed, + MetaSizeInBits); + } + else if (options.initialization.provider == library::Provider::kReferenceHost) { + allocation->initialize_random_sparsemeta_host( + options.initialization.seed, + MetaSizeInBits); + } + } + + return allocation; +} /// Clears named allocations (but does not necessarily free memory) void DeviceContext::clear() { allocations_.clear(); diff --git a/tools/profiler/src/device_context.h b/tools/profiler/src/device_context.h index aea872eff8..1633a2dd29 100644 --- a/tools/profiler/src/device_context.h +++ b/tools/profiler/src/device_context.h @@ -33,6 +33,7 @@ #include "cutlass/library/library.h" +#include "cutlass/library/util.h" #include "options.h" #include "device_allocation.h" @@ -76,7 +77,8 @@ class DeviceContext { library::NumericTypeID type, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector const &stride = std::vector()); + std::vector const &stride = std::vector(), + int batch_count = 1); /// Allocates memory of a given type, capacity (elements), and name DeviceAllocation *allocate_tensor( @@ -85,7 +87,19 @@ class DeviceContext { library::NumericTypeID type, library::LayoutTypeID layout_id, std::vector const &extent, - std::vector const &stride = std::vector()); + std::vector const &stride = std::vector(), + int batch_count = 1); + + /// Allocates memory for sparse meta data + DeviceAllocation *allocate_sparsemeta_tensor( + Options const &options, + std::string const &name, + library::NumericTypeID type, + library::LayoutTypeID layout_id, + library::NumericTypeID type_a, + std::vector const &extent, + std::vector const &stride = std::vector(), + int batch_count = 1); /// Clears named allocations (but does not necessarily free memory) void clear(); diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index f494eeee9f..cf7f8ff64c 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -37,6 +37,9 @@ #include "gemm_operation_profiler.h" #include "gpu_timer.h" +#include "cutlass/library/library.h" +#include "cutlass/library/handle.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { @@ -135,6 +138,8 @@ Status GemmOperationProfiler::GemmProblem::parse( library::GemmDescription const &operation_desc, ProblemSpace const &problem_space, ProblemSpace::Problem const &problem) { + + this->mode = library::GemmUniversalMode::kGemm; if (!arg_as_int(this->m, "m", problem_space, problem)) { // default value @@ -151,6 +156,7 @@ Status GemmOperationProfiler::GemmProblem::parse( this->k = 1024; } + this->mode = library::GemmUniversalMode::kGemm; if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) { // default value this->split_k_slices = 1; @@ -160,6 +166,14 @@ Status GemmOperationProfiler::GemmProblem::parse( // default value this->batch_count = 1; } + else if (this->batch_count > 1) { + this->mode = library::GemmUniversalMode::kBatched; + } + + if (this->split_k_slices > 1 && this->batch_count > 1) { + // At least one of these must be one + return Status::kErrorInvalidProblem; + } if (!tensor_description_satisfies(operation_desc.A, "A", problem_space, problem)) { return Status::kErrorInvalidProblem; @@ -209,6 +223,48 @@ Status GemmOperationProfiler::GemmProblem::parse( return 
Status::kSuccess; } +/// Total number of bytes loaded +int64_t GemmOperationProfiler::GemmProblem::bytes(library::GemmDescription const &operation_desc) const { + // Input bytes read and Output bytes written for the gemm problem + int64_t bytes = + int64_t(library::sizeof_bits(operation_desc.A.element) * m / 8) * k + + int64_t(library::sizeof_bits(operation_desc.B.element) * n / 8) * k + + int64_t(library::sizeof_bits(operation_desc.C.element) * m / 8) * n; + + // Set is_beta_zero true if beta is zero + bool is_beta_zero = std::all_of(beta.begin(), beta.end(), [](uint8_t i) { return i==0; }); + + // Output bytes read for the gemm problem for non-zero beta values + if (!is_beta_zero) { + bytes += int64_t(library::sizeof_bits(operation_desc.C.element) * m / 8) * n; + } + + bytes *= batch_count; + + return bytes; +} + +/// Total number of flops computed +int64_t GemmOperationProfiler::GemmProblem::flops(library::GemmDescription const &operation_desc) const { + int64_t flops_ = (int64_t(m) * n * k + m * n) * 2 * batch_count; + + // complex-valued support + switch (operation_desc.tile_description.math_instruction.math_operation) { + case library::MathOperationID::kMultiplyAddComplex: + flops_ *= 4; + break; + + case library::MathOperationID::kMultiplyAddGaussianComplex: + flops_ *= 3; + break; + + default: break; + } + + return flops_; +} + + /// Initializes a performance result void GemmOperationProfiler::GemmProblem::initialize_result( PerformanceResult &result, @@ -266,6 +322,7 @@ Status GemmOperationProfiler::initialize_configuration( return status; } + gemm_workspace_.configuration.mode = problem_.mode; gemm_workspace_.configuration.problem_size.m() = int(problem_.m); gemm_workspace_.configuration.problem_size.n() = int(problem_.n); gemm_workspace_.configuration.problem_size.k() = int(problem_.k); @@ -273,8 +330,13 @@ Status GemmOperationProfiler::initialize_configuration( gemm_workspace_.configuration.ldb = problem_.ldb; gemm_workspace_.configuration.ldc = problem_.ldc; gemm_workspace_.configuration.ldd = problem_.ldc; - //gemm_workspace_.configuration.split_k_slices = int(problem_.split_k_slices); - gemm_workspace_.configuration.batch_count = int(problem_.split_k_slices); + + if (problem_.mode == library::GemmUniversalMode::kBatched) { + gemm_workspace_.configuration.batch_count = problem_.batch_count; + } + else { + gemm_workspace_.configuration.batch_count = problem_.split_k_slices; + } gemm_workspace_.arguments.A = nullptr; gemm_workspace_.arguments.B = nullptr; @@ -305,32 +367,10 @@ void GemmOperationProfiler::initialize_result_( OperationProfiler::initialize_result_(result, operation_desc, problem_space); - // Input bytes read and Output bytes written for the gemm problem - result.bytes = - int64_t(library::sizeof_bits(operation_desc.A.element) * problem_.m / 8) * problem_.k + - int64_t(library::sizeof_bits(operation_desc.B.element) * problem_.n / 8) * problem_.k + - int64_t(library::sizeof_bits(operation_desc.C.element) * problem_.m / 8) * problem_.n; - - // Set is_beta_zero true if beta is zero - bool is_beta_zero = std::all_of(problem_.beta.begin(), problem_.beta.end(), [](uint8_t i) { return i==0; }); - - // Output bytes read for the gemm problem for non-zero beta values - if (!is_beta_zero) { - result.bytes += int64_t(library::sizeof_bits(operation_desc.C.element) * problem_.m / 8) * problem_.n; - } - - result.flops = 2 * (problem_.m * problem_.n * problem_.k + problem_.m * problem_.n); + result.bytes = problem_.bytes(operation_desc); + result.flops = 
problem_.flops(operation_desc); result.runtime = 0; - // complex-valued support - switch (operation_desc.tile_description.math_instruction.math_operation) { - case library::MathOperationID::kMultiplyAddComplex: - result.flops *= 4; - break; - - default: break; - } - } /// Initializes workspace @@ -345,6 +385,21 @@ Status GemmOperationProfiler::initialize_workspace( library::GemmDescription const &operation_desc = static_cast(operation->description()); + // Compute the number of copies of the problem to avoid L2 camping. + if (!options.profiling.workspace_count) { + int64_t bytes = problem_.bytes(operation_desc); + if (bytes < 3 * int64_t(options.device.properties.l2CacheSize)) { + gemm_workspace_.problem_count = + 1 + int((3 * int64_t(options.device.properties.l2CacheSize)) / bytes); + } + else { + gemm_workspace_.problem_count = 1; + } + } + else { + gemm_workspace_.problem_count = options.profiling.workspace_count; + } + if (options.execution_mode != ExecutionMode::kDryRun) { gemm_workspace_.A = device_context.allocate_tensor( @@ -353,7 +408,8 @@ Status GemmOperationProfiler::initialize_workspace( operation_desc.A.element, operation_desc.A.layout, {int(problem_.m), int(problem_.k)}, - {int(problem_.lda)} + {int(problem_.lda)}, + problem_.batch_count * gemm_workspace_.problem_count ); gemm_workspace_.B = device_context.allocate_tensor( @@ -362,7 +418,8 @@ Status GemmOperationProfiler::initialize_workspace( operation_desc.B.element, operation_desc.B.layout, {int(problem_.k), int(problem_.n)}, - {int(problem_.ldb)} + {int(problem_.ldb)}, + problem_.batch_count * gemm_workspace_.problem_count ); gemm_workspace_.C = device_context.allocate_tensor( @@ -371,7 +428,8 @@ Status GemmOperationProfiler::initialize_workspace( operation_desc.C.element, operation_desc.C.layout, {int(problem_.m), int(problem_.n)}, - {int(problem_.ldc)} + {int(problem_.ldc)}, + problem_.batch_count * gemm_workspace_.problem_count ); gemm_workspace_.Computed = device_context.allocate_tensor( @@ -379,7 +437,8 @@ Status GemmOperationProfiler::initialize_workspace( operation_desc.C.element, operation_desc.C.layout, {int(problem_.m), int(problem_.n)}, - {int(problem_.ldc)} + {int(problem_.ldc)}, + problem_.batch_count * gemm_workspace_.problem_count ); gemm_workspace_.Reference = device_context.allocate_tensor( @@ -387,7 +446,8 @@ Status GemmOperationProfiler::initialize_workspace( operation_desc.C.element, operation_desc.C.layout, {int(problem_.m), int(problem_.n)}, - {int(problem_.ldc)} + {int(problem_.ldc)}, + problem_.batch_count * gemm_workspace_.problem_count ); gemm_workspace_.Reference->copy_from_device(gemm_workspace_.C->data()); @@ -458,6 +518,10 @@ bool GemmOperationProfiler::verify_cutlass( gemm_workspace_.arguments.alpha = problem_.alpha.data(); gemm_workspace_.arguments.beta = problem_.beta.data(); gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; + gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); + gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); + gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); + gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); // // Run the CUTLASS operation @@ -512,6 +576,8 @@ bool GemmOperationProfiler::verify_cutlass( } } #endif // #if CUTLASS_ENABLE_CUBLAS + + verify_with_reference_(options, report, device_context, operation, problem_space, problem); // Update disposition to worst case verification outcome among all // verification providers 
which are supported @@ -561,7 +627,7 @@ bool GemmOperationProfiler::verify_with_cublas_( if (status != CUBLAS_STATUS_SUCCESS) { - results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; + results_.back().verification_map[library::Provider::kCUBLAS] = get_cutlass_disposition(status); return true; } @@ -589,9 +655,13 @@ bool GemmOperationProfiler::verify_with_cublas_( // Initialize structure containing GEMM arguments gemm_workspace_.arguments.A = gemm_workspace_.A->data(); + gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); gemm_workspace_.arguments.B = gemm_workspace_.B->data(); + gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); gemm_workspace_.arguments.C = gemm_workspace_.Reference->data(); + gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.Reference->batch_stride(); gemm_workspace_.arguments.D = gemm_workspace_.Reference->data(); + gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Reference->batch_stride(); gemm_workspace_.arguments.alpha = problem_.alpha.data(); gemm_workspace_.arguments.beta = problem_.beta.data(); gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; @@ -615,7 +685,7 @@ bool GemmOperationProfiler::verify_with_cublas_( // Handle errors if (status != CUBLAS_STATUS_SUCCESS) { - results_.back().verification_map[library::Provider::kCUBLAS] = Disposition::kFailed; + results_.back().verification_map[library::Provider::kCUBLAS] = get_cutlass_disposition(status); return true; } @@ -626,7 +696,8 @@ bool GemmOperationProfiler::verify_with_cublas_( results_.back().verification_map[library::Provider::kCUBLAS] = compare_tensors( options, *gemm_workspace_.Computed, - *gemm_workspace_.Reference + *gemm_workspace_.Reference, + gemm_workspace_.Computed->batch_stride() ); // Save workspace if incorrect @@ -653,6 +724,150 @@ bool GemmOperationProfiler::verify_with_cublas_( ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Verifies CUTLASS against host and device references +bool GemmOperationProfiler::verify_with_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) { + + library::GemmDescription const &gemm_desc = + static_cast(operation->description()); + + // + // Initialize state + // + + library::Provider references[] = { + library::Provider::kReferenceDevice, + library::Provider::kReferenceHost + }; + + for (auto provider : references) { + + // Skip providers that are not enabled + if (!options.verification.provider_enabled(provider)) { + continue; + } + + void *ptr_A = gemm_workspace_.A->data(); + void *ptr_B = gemm_workspace_.B->data(); + void *ptr_C = gemm_workspace_.C->data(); + void *ptr_D = gemm_workspace_.Reference->data(); + + // To support the host-side reference, conditionally allocate and + // copy tensors to host memory. 
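+ // A, B, and C are copied to the host below; D is only allocated here, since the host + // reference writes it and the result is copied back to device memory afterward (see + // copy_from_host below).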
+ std::vector host_data_A; + std::vector host_data_B; + std::vector host_data_C; + std::vector host_data_D; + + if (provider == library::Provider::kReferenceHost) { + + host_data_A.resize(gemm_workspace_.A->bytes()); + ptr_A = host_data_A.data(); + gemm_workspace_.A->copy_to_host(ptr_A); + + host_data_B.resize(gemm_workspace_.B->bytes()); + ptr_B = host_data_B.data(); + gemm_workspace_.B->copy_to_host(ptr_B); + + host_data_C.resize(gemm_workspace_.C->bytes()); + ptr_C = host_data_C.data(); + gemm_workspace_.C->copy_to_host(ptr_C); + + host_data_D.resize(gemm_workspace_.Reference->bytes()); + ptr_D = host_data_D.data(); + } + + // + // Launch + // + + library::Handle handle; + + handle.set_provider(provider); + + Status status = handle.gemm_universal( + library::GemmUniversalMode::kGemm, + gemm_workspace_.configuration.problem_size.m(), + gemm_workspace_.configuration.problem_size.n(), + gemm_workspace_.configuration.problem_size.k(), + gemm_desc.tile_description.math_instruction.element_accumulator, + gemm_desc.element_epilogue, + + problem_.alpha.data(), + + gemm_desc.A.element, + gemm_desc.A.layout, + gemm_desc.transform_A, + ptr_A, + int(gemm_workspace_.configuration.lda), + + gemm_desc.B.element, + gemm_desc.B.layout, + gemm_desc.transform_B, + ptr_B, + int(gemm_workspace_.configuration.ldb), + + problem_.beta.data(), + + gemm_desc.C.element, + ptr_C, + int(gemm_workspace_.configuration.ldc), + + ptr_D, + int(gemm_workspace_.configuration.ldd), + + gemm_workspace_.configuration.batch_count, + gemm_workspace_.A->batch_stride(), + gemm_workspace_.B->batch_stride(), + gemm_workspace_.C->batch_stride(), + gemm_workspace_.Reference->batch_stride() + ); + + if (status != Status::kSuccess) { + results_.back().verification_map[provider] = Disposition::kNotRun; + return true; + } + + results_.back().status = status; + + if (provider == library::Provider::kReferenceHost) { + gemm_workspace_.Reference->copy_from_host(ptr_D); + } + + // + // Verify results + // + + results_.back().verification_map[provider] = compare_tensors( + options, + *gemm_workspace_.Computed, + *gemm_workspace_.Reference, + gemm_workspace_.Computed->batch_stride() + ); + + // Save workspace if incorrect + if (options.verification.save_workspace == SaveWorkspace::kIncorrect && + results_.back().verification_map[provider] == Disposition::kIncorrect) { + + save_workspace( + device_context, + options, + gemm_desc, + library::Provider::kCUTLASS, + provider); + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Measures performance results bool GemmOperationProfiler::profile( Options const &options, @@ -672,6 +887,10 @@ bool GemmOperationProfiler::profile( gemm_workspace_.arguments.alpha = problem_.alpha.data(); gemm_workspace_.arguments.beta = problem_.beta.data(); gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; + gemm_workspace_.arguments.batch_stride_A = gemm_workspace_.A->batch_stride(); + gemm_workspace_.arguments.batch_stride_B = gemm_workspace_.B->batch_stride(); + gemm_workspace_.arguments.batch_stride_C = gemm_workspace_.C->batch_stride(); + gemm_workspace_.arguments.batch_stride_D = gemm_workspace_.Computed->batch_stride(); results_.back().status = profile_cutlass_( results_.back().runtime, @@ -687,6 +906,100 @@ bool GemmOperationProfiler::profile( ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Method to profile a CUTLASS Operation +Status 
GemmOperationProfiler::profile_cutlass_( + double &runtime, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace) { + + GpuTimer timer; + + // + // Optional sleep to limit power consumption and thermals + // + + sleep(options.profiling.sleep_duration); + + // + // Warmup loop + // + + Status status; + + for (int iteration = 0; iteration < options.profiling.warmup_iterations; ++iteration) { + + int problem_idx = (iteration % gemm_workspace_.problem_count) * problem_.batch_count; + + gemm_workspace_.arguments.A = gemm_workspace_.A->batch_data(problem_idx); + gemm_workspace_.arguments.B = gemm_workspace_.B->batch_data(problem_idx); + gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx); + gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx); + + // Execute the CUTLASS operation + status = operation->run( + &gemm_workspace_.arguments, + host_workspace, + device_workspace); + + if (status != Status::kSuccess) { + return status; + } + } + + // + // Initialize GPU timer + // + + timer.start(); + + // + // Profiling loop + // + + int Iterations = options.profiling.iterations; + + int iteration = 0; + for (; iteration < Iterations; ++iteration) { + + // Iterate over copies of the problem in memory + int workspace_idx = options.profiling.warmup_iterations + iteration; + int problem_idx = (workspace_idx % gemm_workspace_.problem_count) * problem_.batch_count; + + gemm_workspace_.arguments.A = gemm_workspace_.A->batch_data(problem_idx); + gemm_workspace_.arguments.B = gemm_workspace_.B->batch_data(problem_idx); + gemm_workspace_.arguments.C = gemm_workspace_.C->batch_data(problem_idx); + gemm_workspace_.arguments.D = gemm_workspace_.Computed->batch_data(problem_idx); + + status = operation->run( + arguments, + host_workspace, + device_workspace); + + if (status != Status::kSuccess) { + return status; + } + } + + // + // Wait for completion + // + + timer.stop_and_wait(); + + // + // Update performance result + // + + runtime = timer.duration(iteration); + + return status; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace profiler } // namespace cutlass diff --git a/tools/profiler/src/gemm_operation_profiler.h b/tools/profiler/src/gemm_operation_profiler.h index e4d23212e1..1c6c5e7ceb 100644 --- a/tools/profiler/src/gemm_operation_profiler.h +++ b/tools/profiler/src/gemm_operation_profiler.h @@ -59,6 +59,8 @@ class GemmOperationProfiler : public OperationProfiler { /// Problem structure obtained from problem space struct GemmProblem { + + cutlass::library::GemmUniversalMode mode; int64_t m; int64_t n; int64_t k; @@ -67,14 +69,15 @@ class GemmOperationProfiler : public OperationProfiler { int64_t ldc; std::vector alpha; std::vector beta; - int64_t split_k_slices; - int64_t batch_count; + int split_k_slices; + int batch_count; // // Methods // GemmProblem(): + mode(library::GemmUniversalMode::kGemm), m(16), n(16), k(16), lda(0), ldb(0), ldc(0), split_k_slices(1), batch_count(1) { } /// Parses the problem @@ -83,6 +86,12 @@ class GemmOperationProfiler : public OperationProfiler { ProblemSpace const &problem_space, ProblemSpace::Problem const &problem); + /// Total number of bytes loaded + int64_t bytes(library::GemmDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::GemmDescription const &operation_desc) const; + /// Initializes a performance result void 
initialize_result( PerformanceResult &result, @@ -99,6 +108,10 @@ class GemmOperationProfiler : public OperationProfiler { DeviceAllocation *Computed; DeviceAllocation *Reference; + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + int problem_count; + library::GemmUniversalConfiguration configuration; library::GemmUniversalArguments arguments; @@ -113,7 +126,7 @@ class GemmOperationProfiler : public OperationProfiler { // GemmWorkspace(): - A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } + A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr), problem_count(1) { } }; protected: @@ -200,6 +213,24 @@ class GemmOperationProfiler : public OperationProfiler { ProblemSpace const &problem_space, ProblemSpace::Problem const &problem); + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Method to profile a CUTLASS Operation + Status profile_cutlass_( + double &runtime, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/profiler/src/operation_profiler.cu b/tools/profiler/src/operation_profiler.cu index 754118a738..2bbf2eeb11 100644 --- a/tools/profiler/src/operation_profiler.cu +++ b/tools/profiler/src/operation_profiler.cu @@ -268,6 +268,7 @@ int OperationProfiler::profile_all( std::string operation_name(operation->description().name); + // Filter kernels by name bool filtered_by_name = options.operation_names.empty(); if (!filtered_by_name) { @@ -279,6 +280,13 @@ int OperationProfiler::profile_all( } } + for (auto const & op_name : options.excluded_operation_names) { + if (find_string_matches_(op_name, operation_name)) { + filtered_by_name = false; + break; + } + } + if (!filtered_by_name || !satisfies(operation->description(), problem_space, problem)) { continue; } @@ -410,7 +418,8 @@ void OperationProfiler::sleep(int sleep_duration) { Disposition OperationProfiler::compare_tensors( Options const &options, DeviceAllocation &experimental, - DeviceAllocation &reference) { + DeviceAllocation &reference, + int64_t count) { if (experimental.type() != reference.type()) { return Disposition::kIncorrect; @@ -418,6 +427,10 @@ Disposition OperationProfiler::compare_tensors( bool passed = false; + if (count == 0) { + count = reference.capacity(); + } + if (options.verification.epsilon == 0) { // bit-level equality @@ -425,7 +438,7 @@ Disposition OperationProfiler::compare_tensors( experimental.type(), experimental.data(), reference.data(), - experimental.capacity()); + count); } else { @@ -434,7 +447,7 @@ Disposition OperationProfiler::compare_tensors( experimental.type(), experimental.data(), reference.data(), - experimental.capacity(), + count, options.verification.epsilon, options.verification.nonzero_floor); } @@ -483,7 +496,7 @@ Status OperationProfiler::profile_cutlass_( double &runtime, Options const &options, library::Operation const *operation, - void const *arguments, + void *arguments, void *host_workspace, void *device_workspace) { diff --git a/tools/profiler/src/operation_profiler.h 
b/tools/profiler/src/operation_profiler.h index c7e20f36f7..731554b6f2 100644 --- a/tools/profiler/src/operation_profiler.h +++ b/tools/profiler/src/operation_profiler.h @@ -189,7 +189,8 @@ class OperationProfiler { static Disposition compare_tensors( Options const &options, DeviceAllocation &experimental, - DeviceAllocation &reference); + DeviceAllocation &reference, + int64_t count = 0); static void save_workspace( DeviceContext &device_context, @@ -225,7 +226,7 @@ class OperationProfiler { double &runtime, Options const &options, library::Operation const *operation, - void const *arguments, + void *arguments, void *host_workspace, void *device_workspace); diff --git a/tools/profiler/src/options.cu b/tools/profiler/src/options.cu index 5f62a81e73..e2d3e131f0 100644 --- a/tools/profiler/src/options.cu +++ b/tools/profiler/src/options.cu @@ -70,6 +70,16 @@ Options::Device::Device(cutlass::CommandLine const &cmdline) { properties.major = cc / 10; properties.minor = cc % 10; } + + // Permit overriding the L2 cache capacity + if (cmdline.check_cmd_line_flag("llc-capacity")) { + int llc_capacity = 0; + cmdline.get_cmd_line_argument("llc-capacity", llc_capacity, 0); + + if (llc_capacity >= 0) { + properties.l2CacheSize = (llc_capacity << 10); + } + } } @@ -107,7 +117,12 @@ void Options::Device::print_usage(std::ostream &out) const { out << " --compute-capability= " - << " Override the compute capability.\n\n"; + << " Override the compute capability.\n\n" + + << " --llc-capacity= " + << " Capacity of last-level cache in kilobytes. If this is non-zero," << end_of_line + << " profiling phases cycle through different input tensors to induce" << end_of_line + << " capacity misses in the L2.\n\n"; } @@ -189,6 +204,7 @@ Options::Initialization::Initialization(cutlass::CommandLine const &cmdline) { // set uniform data distribution with range [-4, 4] data_distribution.set_uniform(-4, 4, 0); } + } @@ -364,7 +380,8 @@ void Options::Library::print_options(std::ostream &out, int indent) const { ///////////////////////////////////////////////////////////////////////////////////////////////// Options::Profiling::Profiling(cutlass::CommandLine const &cmdline) { - + + cmdline.get_cmd_line_argument("workspace-count", workspace_count, 0); cmdline.get_cmd_line_argument("warmup-iterations", warmup_iterations, 10); cmdline.get_cmd_line_argument("profiling-iterations", iterations, 100); cmdline.get_cmd_line_argument("sleep-duration", sleep_duration, 50); @@ -391,6 +408,11 @@ void Options::Profiling::print_usage(std::ostream &out) const { out << "Profiling:\n" + << " --workspace-count= " + << " Number of discrete workspaces maintained to avoid cache-resident " << end_of_line + << " If zero (default), the amount is chosen for each workload based on " << end_of_line + << " capacity of the last-level cache.\n\n" + << " --profiling-iterations= " << " Number of iterations to profile each kernel. 
If zero, kernels" << end_of_line << " are launched up to the profiling duration.\n\n" @@ -672,6 +694,10 @@ Options::Options(cutlass::CommandLine const &cmdline): cmdline.get_cmd_line_arguments("kernels", operation_names); } + if (cmdline.check_cmd_line_flag("ignore-kernels")) { + cmdline.get_cmd_line_arguments("ignore-kernels", excluded_operation_names); + } + // Prevent launches on the device for anything other than CUTLASS operation if (execution_mode == ExecutionMode::kTrace) { initialization.provider = library::Provider::kReferenceHost; @@ -706,6 +732,9 @@ void Options::print_usage(std::ostream &out) const { << " Filter operations by kernel names. For example, call all kernels with" << end_of_line << " (\"s1688\" and \"nt\") or (\"s844\" and \"tn\" and \"align8\") in their" << end_of_line << " operation name using --kernels=\"s1688*nt, s884*tn*align8\"\n\n" + + << " --ignore-kernels= " + << " Excludes kernels whose names match anything in this list.\n\n" ; // diff --git a/tools/profiler/src/options.h b/tools/profiler/src/options.h index f4b5f0a130..48463efa50 100644 --- a/tools/profiler/src/options.h +++ b/tools/profiler/src/options.h @@ -175,6 +175,9 @@ class Options { /// Options related to profiling struct Profiling { + /// Number of workspaces to rotate through to avoid cache-resident working sets + int workspace_count; + /// Number of iterations to warmup each kernel prior to profiling int warmup_iterations; @@ -273,6 +276,10 @@ class Options { /// Vector of operation name substrings std::vector operation_names; + + /// Vector of operation name substrings + std::vector excluded_operation_names; + // // Detailed configuration options diff --git a/tools/profiler/src/performance_report.cpp b/tools/profiler/src/performance_report.cpp index 0ab7044929..07a7edc955 100644 --- a/tools/profiler/src/performance_report.cpp +++ b/tools/profiler/src/performance_report.cpp @@ -263,6 +263,10 @@ std::ostream & PerformanceReport::print_csv_header_( << ",OperationKind,Operation,Disposition,Status"; for (auto const &arg_name : argument_names_) { + // Operand E is internal to the sparse kernel + if (arg_name.compare("E") == 0) + continue; + out << "," << arg_name; } diff --git a/tools/profiler/src/problem_space.cpp b/tools/profiler/src/problem_space.cpp index adede0ea1f..e69b0110e9 100644 --- a/tools/profiler/src/problem_space.cpp +++ b/tools/profiler/src/problem_space.cpp @@ -577,19 +577,65 @@ void ProblemSpace::parse_(KernelArgument *arg, CommandLine const &cmdline) { std::vector > tokens; cmdline.get_cmd_line_argument_ranges(alias.c_str(), tokens); - for (auto const &range_tokens : tokens) { + for (auto &range_tokens : tokens) { if (!range_tokens.empty()) { - Range range(lexical_cast(range_tokens.front())); - if (range_tokens.size() > 1) { - range.last = lexical_cast(range_tokens.at(1)); - } + Range range; - if (range_tokens.size() > 2) { - range.increment = lexical_cast(range_tokens.at(2)); + if (range_tokens.front() == "rand") { + range.mode = Range::Mode::kRandom; + } + else if (range_tokens.front() == "randlg2") { + range.mode = Range::Mode::kRandomLog2; } + switch (range.mode) { + case Range::Mode::kSequence: + { + range.first = lexical_cast(range_tokens.front()); + + if (range_tokens.size() > 1) { + range.last = lexical_cast(range_tokens.at(1)); + } + else { + range.last = range.first; + } + + if (range_tokens.size() > 2) { + range.increment = lexical_cast(range_tokens.at(2)); + } + else { + range.increment = 1; + } + } + break; + case Range::Mode::kRandom: // fall-through + case 
Range::Mode::kRandomLog2: + { + if (range_tokens.size() < 4) { + throw std::runtime_error( + "Range of mode 'rand' must have four tokens showing " + "the minimum, maximum, and number of iterations. For example, " + "rand:16:128:1000"); + } + + range.minimum = lexical_cast(range_tokens.at(1)); + range.maximum = lexical_cast(range_tokens.at(2)); + range.first = 1; + range.last = lexical_cast(range_tokens.at(3)); + range.increment = 1; + + if (range_tokens.size() > 4) { + range.divisible = lexical_cast(range_tokens.at(4)); + } + } + break; + default: + throw std::runtime_error("Unsupported range mode."); + break; + } + integer->ranges.push_back(range); } } @@ -713,6 +759,30 @@ bool arg_as_int(int64_t &int_value, KernelArgument::Value const *value_ptr) { return false; } +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_int(int &int_value, KernelArgument::Value const *value_ptr) { + int64_t value64; + bool obtained = arg_as_int(value64, value_ptr); + if (obtained) { + int_value = int(value64); + return true; + } + return false; +} + +/// Lexically casts an argument to an int +bool arg_as_int( + int &int_value, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) { + + size_t idx = problem_space.argument_index(name); + KernelArgument::Value const *value_ptr = problem.at(idx).get(); + + return arg_as_int(int_value, value_ptr); +} + /// Lexically casts an argument to an int64 bool arg_as_int( int64_t &int_value, diff --git a/tools/profiler/src/problem_space.h b/tools/profiler/src/problem_space.h index 77a79ca2a6..8a9ee4f2e8 100644 --- a/tools/profiler/src/problem_space.h +++ b/tools/profiler/src/problem_space.h @@ -49,6 +49,7 @@ #include #include #include +#include // CUTLASS Utility includes #include "cutlass/util/command_line.h" @@ -307,10 +308,18 @@ struct Range { // Type definitions // + enum class Mode { + kSequence, + kRandom, + kRandomLog2, + kInvalid + }; + struct Iterator { int64_t value; int64_t increment; + Range const *range; // // Methods @@ -318,9 +327,10 @@ struct Range { Iterator( int64_t value_ = 0, - int64_t increment_ = 1 + int64_t increment_ = 1, + Range const *range_ = nullptr ): - value(value_), increment(increment_) { } + value(value_), increment(increment_), range(range_) { } Iterator & operator++() { value += increment; @@ -341,7 +351,50 @@ struct Range { return !(*this == it); } + static int64_t round(int64_t value, int64_t divisible) { + int64_t rem = (value % divisible); + + // Round either up or down + if (rem > divisible / 2) { + value += (divisible - rem); + } + else { + value -= rem; + } + + return value; + } + int64_t at() const { + if (!range) { + return value; + } + + switch (range->mode) { + case Mode::kSequence: return value; + + case Mode::kRandom: { + double rnd = double(range->minimum) + + double(std::rand()) / double(RAND_MAX) * (double(range->maximum) - double(range->minimum)); + + int64_t value = int64_t(rnd); + + return round(value, range->divisible); + } + break; + + case Mode::kRandomLog2: { + double lg2_minimum = std::log(double(range->minimum)) / std::log(2.0); + double lg2_maximum = std::log(double(range->maximum)) / std::log(2.0); + double rnd = lg2_minimum + double(std::rand()) / double(RAND_MAX) * (lg2_maximum - lg2_minimum); + + int64_t value = int64_t(std::pow(2.0, rnd)); + + return round(value, range->divisible); + } + break; + default: break; + } return value; } @@ -357,20 +410,29 @@ struct Range { int64_t first; ///< first element in range int64_t 
last; ///< last element in range int64_t increment; ///< additive increment between values + + Mode mode; ///< mode selection enables alternative values + int64_t minimum; ///< minimum value to return + int64_t maximum; ///< maximum value to return + int64_t divisible; ///< rounds value down to an integer multiple of this value // // Methods // /// Default constructor - range acts as a scalar - Range(int64_t first_ = 0): first(first_), last(first_), increment(1) { } + Range(int64_t first_ = 0): first(first_), last(first_), increment(1), mode(Mode::kSequence), minimum(0), maximum(0), divisible(1) { } /// Range acts as a range Range( int64_t first_, int64_t last_, - int64_t increment_ = 1 - ): first(first_), last(last_), increment(increment_) { + int64_t increment_ = 1, + Mode mode_ = Mode::kSequence, + int64_t minimum_ = 0, + int64_t maximum_ = 0, + int64_t divisible_ = 1 + ): first(first_), last(last_), increment(increment_), mode(mode_), minimum(minimum_), maximum(maximum_), divisible(divisible_) { // Helpers to avoid constructing invalid ranges if (increment > 0) { @@ -389,14 +451,29 @@ struct Range { } } + /// Helper to construct a sequence range + static Range Sequence(int64_t first_, int64_t last_, int64_t increment_ = 1) { + return Range(first_, last_, increment_, Mode::kSequence); + } + + /// Helper to construct a range that is a random distribution + static Range Random(int64_t minimum_, int64_t maximum_, int64_t count_, int64_t divisible_ = 1) { + return Range(1, count_, 1, Mode::kRandom, minimum_, maximum_, divisible_); + } + + /// Helper to construct a range that is a random distribution over a log scale + static Range RandomLog2(int64_t minimum_, int64_t maximum_, int64_t count_, int64_t divisible_ = 1) { + return Range(1, count_, 1, Mode::kRandomLog2, minimum_, maximum_, divisible_); + } + /// Returns an iterator to the first element within the range Iterator begin() const { - return Iterator(first, increment); + return Iterator(first, increment, this); } /// Returns an iterator to the first element *after* the range Iterator end() const { - return Iterator(first + ((last - first)/increment + 1) * increment, increment); + return Iterator(first + ((last - first)/increment + 1) * increment, increment, this); } }; @@ -770,9 +847,19 @@ class ProblemSpace { ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Lexically casts an argument to an int if it is defined. Returns true if not null. +bool arg_as_int(int &int_value, KernelArgument::Value const *value_ptr); + /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_int(int64_t &int_value, KernelArgument::Value const *value_ptr); +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_int( + int &int_value, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + /// Lexically casts an argument to an int64 if it is defined. Returns true if not null. bool arg_as_int( int64_t &int_value, diff --git a/tools/profiler/src/sparse_gemm_operation_profiler.cu b/tools/profiler/src/sparse_gemm_operation_profiler.cu new file mode 100644 index 0000000000..702b79bb6c --- /dev/null +++ b/tools/profiler/src/sparse_gemm_operation_profiler.cu @@ -0,0 +1,560 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. 
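The random range modes added above draw each value uniformly (rand) or on a log2 scale (randlg2) between minimum and maximum, round it to a multiple of divisible, and repurpose first/last/increment as a simple trip count. The standalone sketch below mirrors that arithmetic outside the profiler; the command-line spellings in its final comment follow the mode:min:max:count[:divisible] token order implied by the parser and error message above and are illustrative, not taken from the profiler's documentation.

// Standalone sketch of the random range sampling introduced above.
// It mirrors Range::Iterator::at() for Mode::kRandom and Mode::kRandomLog2:
// draw a value in [minimum, maximum] (linearly or on a log2 scale), then
// round it to the nearest multiple of 'divisible'.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Same rounding rule as Range::Iterator::round(): snap to the nearest
// multiple of 'divisible', rounding halves down.
static int64_t round_to_multiple(int64_t value, int64_t divisible) {
  int64_t rem = value % divisible;
  return (rem > divisible / 2) ? value + (divisible - rem) : value - rem;
}

// Uniform draw, as in Mode::kRandom
static int64_t sample_uniform(int64_t minimum, int64_t maximum, int64_t divisible) {
  double rnd = double(minimum) +
    double(std::rand()) / double(RAND_MAX) * double(maximum - minimum);
  return round_to_multiple(int64_t(rnd), divisible);
}

// Log2-scale draw, as in Mode::kRandomLog2
static int64_t sample_log2(int64_t minimum, int64_t maximum, int64_t divisible) {
  double lg2_min = std::log2(double(minimum));
  double lg2_max = std::log2(double(maximum));
  double rnd = lg2_min + double(std::rand()) / double(RAND_MAX) * (lg2_max - lg2_min);
  return round_to_multiple(int64_t(std::pow(2.0, rnd)), divisible);
}

int main() {
  // Ten samples of each mode, rounded to multiples of 8 - roughly the effect of
  // command-line ranges such as --m=rand:16:4096:10:8 or --m=randlg2:16:4096:10:8.
  for (int i = 0; i < 10; ++i) {
    std::printf("uniform: %5lld   log2: %5lld\n",
                (long long)sample_uniform(16, 4096, 8),
                (long long)sample_log2(16, 4096, 8));
  }
  return 0;
}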
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Execution environment + +*/ + +#include +#include +#include +#include + +#include "cublas_helpers.h" +#include "sparse_gemm_operation_profiler.h" +#include "gpu_timer.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Ctor +SparseGemmOperationProfiler::SparseGemmOperationProfiler(Options const &options): + OperationProfiler( + options, + library::OperationKind::kSparseGemm, + { + {ArgumentTypeID::kEnumerated, {"gemm_kind"}, "Variant of GEMM (e.g. gemm, planar complex, batched, ...)"}, + {ArgumentTypeID::kInteger, {"m", "problem-size::m"}, "M dimension of the GEMM problem space"}, + {ArgumentTypeID::kInteger, {"n", "problem-size::n"}, "N dimension of the GEMM problem space"}, + {ArgumentTypeID::kInteger, {"k", "problem-size::k"}, "K dimension of the GEMM problem space"}, + {ArgumentTypeID::kTensor, {"A"}, "Tensor storing the A operand"}, + {ArgumentTypeID::kTensor, {"B"}, "Tensor storing the B operand"}, + {ArgumentTypeID::kTensor, {"C"}, "Tensor storing the C operand"}, + {ArgumentTypeID::kTensor, {"E"}, "Tensor storing the E operand"}, + {ArgumentTypeID::kScalar, {"alpha", "epilogue::alpha"}, "Epilogue scalar alpha"}, + {ArgumentTypeID::kScalar, {"beta", "epilogue::beta"}, "Epilogue scalar beta"}, + {ArgumentTypeID::kInteger, {"split_k_slices"}, "Number of partitions of K dimension"}, + {ArgumentTypeID::kInteger, {"batch_count"}, "Number of GEMMs computed in one batch"}, + } + ) { + + description_ = " Structured sparse GEMM. 
D = alpha * A*B + beta * C"; +} + +/// Destructor +SparseGemmOperationProfiler::~SparseGemmOperationProfiler() { + +} + +/// Prints usage statement for the math function +void SparseGemmOperationProfiler::print_usage(std::ostream &out) const { + out << "Sparse GEMM" << "\n\n"; + + OperationProfiler::print_usage(out); +} + +/// Prints examples +void SparseGemmOperationProfiler::print_examples(std::ostream &out) const { + + out << "\nExamples:\n\n" + << "Profile a particular problem size:\n" + << " $ cutlass_profiler --operation=SparseGemm --m=1024 --n=1024 --k=128\n\n" + + << "Schmoo over problem size and beta:\n" + << " $ cutlass_profiler --operation=SparseGemm --m=1024:4096:256 --n=1024:4096:256 --k=128:8192:128 --beta=0,1,2.5\n\n" + + << "Schmoo over accumulator types:\n" + << " $ cutlass_profiler --operation=SparseGemm --accumulator-type=f16,f32\n\n" + + << "Run when A is f16 with column-major and B is any datatype with row-major (For column major, use column, col, or n. For row major use, row or t):\n" + << " $ cutlass_profiler --operation=SparseGemm --A=f16:column --B=*:row\n\n" + + << "Using various input value distribution:\n" + << " $ cutlass_profiler --operation=SparseGemm --dist=uniform,min:0,max:3\n" + << " $ cutlass_profiler --operation=SparseGemm --dist=gaussian,mean:0,stddev:3\n" + << " $ cutlass_profiler --operation=SparseGemm --dist=sequential,start:0,delta:1\n\n" + + << "Run a kernel with cta tile size of 256x128x32 and save workspace if results are incorrect (note that --cta-tile::k=32 is default cta-tile size):\n" + << " $ cutlass_profiler --operation=SparseGemm --cta_m=256 --cta_n=128 --cta_k=32 --save-workspace=incorrect\n\n" + + << "Test your changes to gemm kernels with a quick functional test and save results in functional-test.csv:\n" + << " $ cutlass_profiler --operation=SparseGemm \\ \n" + << " --m=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" + << " --n=8,56,120,136,256,264,512,520,1024,1032,4096,8192,16384 \\ \n" + << " --k=8,16,32,64,128,256,288,384,504,512,520 \\ \n" + << " --beta=0,1,2 --profiling-iterations=1 \\ \n" + << " --providers=cutlass --output=functional-test.csv\n\n"; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +Status SparseGemmOperationProfiler::SparseGemmProblem::parse( + library::SparseGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) { + + if (!arg_as_int(this->m, "m", problem_space, problem)) { + // default value + this->m = 1024; + } + + if (!arg_as_int(this->n, "n", problem_space, problem)) { + // default value + this->n = 1024; + } + + if (!arg_as_int(this->k, "k", problem_space, problem)) { + // default value + this->k = 1024; + } + + if (!arg_as_int(this->split_k_slices, "split_k_slices", problem_space, problem)) { + // default value + this->split_k_slices = 1; + } + + if (!arg_as_int(this->batch_count, "batch_count", problem_space, problem)) { + // default value + this->batch_count = 1; + } + + if (!tensor_description_satisfies(operation_desc.A, "A", problem_space, problem)) { + return Status::kErrorInvalidProblem; + } + + if (!tensor_description_satisfies(operation_desc.B, "B", problem_space, problem)) { + return Status::kErrorInvalidProblem; + } + + if (!tensor_description_satisfies(operation_desc.C, "C", problem_space, problem)) { + return Status::kErrorInvalidProblem; + } + + if (!tensor_description_satisfies(operation_desc.E, "E", problem_space, problem)) { + return 
Status::kErrorInvalidProblem; + } + + if (!arg_as_scalar( + this->alpha, + operation_desc.element_epilogue, + "alpha", + problem_space, + problem)) { + + if (!cast_from_double(this->alpha, operation_desc.element_epilogue, 1)) { + return Status::kErrorInternal; + } + } + + if (!arg_as_scalar( + this->beta, + operation_desc.element_epilogue, + "beta", + problem_space, + problem)) { + + if (!cast_from_double(this->beta, operation_desc.element_epilogue, 0)) { + return Status::kErrorInternal; + } + } + + this->elements_per_128b = + 128 / library::sizeof_bits(operation_desc.A.element); + + this->lda = DeviceAllocation::get_packed_layout( + operation_desc.A.layout, + {int(this->m), int(this->k) / int(this->sparse)}) + .front(); + + this->ldb = DeviceAllocation::get_packed_layout( + operation_desc.B.layout, {int(this->k), int(this->n)}).front(); + + this->ldc = DeviceAllocation::get_packed_layout( + operation_desc.C.layout, {int(this->m), int(this->n)}).front(); + + this->lde = + DeviceAllocation::get_packed_layout( + operation_desc.E.layout, + {int(this->m), int(this->k / this->sparse / this->elements_per_128b)}) + .front(); + + return Status::kSuccess; +} + +/// Initializes a performance result +void SparseGemmOperationProfiler::SparseGemmProblem::initialize_result( + PerformanceResult &result, + library::SparseGemmDescription const &operation_desc, + ProblemSpace const &problem_space) { + + result.arguments.resize(problem_space.rank()); + + set_argument(result, "gemm_kind", problem_space, library::to_string(operation_desc.gemm_kind)); + + set_argument(result, "A", problem_space, + std::string(library::to_string(operation_desc.A.element)) + ":" + library::to_string(operation_desc.A.layout)); + + set_argument(result, "B", problem_space, + std::string(library::to_string(operation_desc.B.element)) + ":" + library::to_string(operation_desc.B.layout)); + + set_argument(result, "C", problem_space, + std::string(library::to_string(operation_desc.C.element)) + ":" + library::to_string(operation_desc.C.layout)); + + set_argument(result, "m", problem_space, m); + set_argument(result, "n", problem_space, n); + set_argument(result, "k", problem_space, k); + + set_argument(result, "split_k_slices", problem_space, split_k_slices); + set_argument(result, "batch_count", problem_space, batch_count); + + set_argument(result, "alpha", problem_space, + library::lexical_cast(alpha, operation_desc.element_epilogue)); + + set_argument(result, "beta", problem_space, + library::lexical_cast(beta, operation_desc.element_epilogue)); +} + +/// Extracts the problem dimensions +Status SparseGemmOperationProfiler::initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) { + + library::SparseGemmDescription const &operation_desc = + static_cast(operation->description()); + + if (operation_desc.gemm_kind != library::GemmKind::kSparse) { + return Status::kErrorInvalidProblem; + } + + Status status = problem_.parse(operation_desc, problem_space, problem); + + if (status != Status::kSuccess) { + return status; + } + + gemm_workspace_.configuration.problem_size.m() = int(problem_.m); + gemm_workspace_.configuration.problem_size.n() = int(problem_.n); + gemm_workspace_.configuration.problem_size.k() = int(problem_.k); + gemm_workspace_.configuration.lda = problem_.lda; + gemm_workspace_.configuration.ldb = problem_.ldb; + gemm_workspace_.configuration.ldc = 
problem_.ldc; + gemm_workspace_.configuration.ldd = problem_.ldc; + gemm_workspace_.configuration.lde = problem_.lde; + + gemm_workspace_.arguments.A = nullptr; + gemm_workspace_.arguments.B = nullptr; + gemm_workspace_.arguments.C = nullptr; + gemm_workspace_.arguments.D = nullptr; + gemm_workspace_.arguments.E = nullptr; + gemm_workspace_.arguments.alpha = problem_.alpha.data(); + gemm_workspace_.arguments.beta = problem_.beta.data(); + gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; + + initialize_result_(this->model_result_, options, operation_desc, problem_space); + + return operation->can_implement(&gemm_workspace_.configuration, &gemm_workspace_.arguments); +} + +/// Initializes the performance result +void SparseGemmOperationProfiler::initialize_result_( + PerformanceResult &result, + Options const &options, + library::SparseGemmDescription const &operation_desc, + ProblemSpace const &problem_space) { + + result.provider = library::Provider::kCUTLASS; + result.disposition = Disposition::kNotRun; + result.status = Status::kSuccess; + result.operation_name = operation_desc.name; + + problem_.initialize_result(result, operation_desc, problem_space); + + OperationProfiler::initialize_result_(result, operation_desc, problem_space); + + // Input bytes read and Output bytes written for the gemm problem + result.bytes = + int64_t(library::sizeof_bits(operation_desc.A.element) * problem_.m / 8) * + problem_.k / problem_.sparse + + int64_t(library::sizeof_bits(operation_desc.B.element) * problem_.n / 8) * + problem_.k + + int64_t(library::sizeof_bits(operation_desc.C.element) * problem_.m / 8) * + problem_.n + + int64_t(library::sizeof_bits(operation_desc.E.element) * problem_.m / 8) * + problem_.k / problem_.sparse / problem_.elements_per_128b; + + // Set is_beta_zero true if beta is zero + bool is_beta_zero = std::all_of(problem_.beta.begin(), problem_.beta.end(), [](uint8_t i) { return i==0; }); + + // Output bytes read for the gemm problem for non-zero beta values + if (!is_beta_zero) { + result.bytes += int64_t(library::sizeof_bits(operation_desc.C.element) * problem_.m / 8) * problem_.n; + } + + result.flops = 2 * (problem_.m * problem_.n * problem_.k + problem_.m * problem_.n); + result.runtime = 0; + +} + +/// Initializes workspace +Status SparseGemmOperationProfiler::initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) { + + library::SparseGemmDescription const &operation_desc = + static_cast(operation->description()); + + if (options.execution_mode != ExecutionMode::kDryRun) { + + gemm_workspace_.A = device_context.allocate_tensor( + options, + "A", + operation_desc.A.element, + operation_desc.A.layout, + {int(problem_.m), int(problem_.k) / int(problem_.sparse)}, + {int(problem_.lda)} + ); + + gemm_workspace_.B = device_context.allocate_tensor( + options, + "B", + operation_desc.B.element, + operation_desc.B.layout, + {int(problem_.k), int(problem_.n)}, + {int(problem_.ldb)} + ); + + gemm_workspace_.C = device_context.allocate_tensor( + options, + "C", + operation_desc.C.element, + operation_desc.C.layout, + {int(problem_.m), int(problem_.n)}, + {int(problem_.ldc)} + ); + + gemm_workspace_.Computed = device_context.allocate_tensor( + "D", + operation_desc.C.element, + operation_desc.C.layout, + {int(problem_.m), int(problem_.n)}, + {int(problem_.ldc)} + ); + + gemm_workspace_.E = 
device_context.allocate_sparsemeta_tensor( + options, + "E", + operation_desc.E.element, + operation_desc.E.layout, + operation_desc.A.element, + {int(problem_.m), int(problem_.k) / int(problem_.sparse) / int(problem_.elements_per_128b)}, + {int(problem_.lde)} + ); + + gemm_workspace_.Reference = device_context.allocate_tensor( + "Reference", + operation_desc.C.element, + operation_desc.C.layout, + {int(problem_.m), int(problem_.n)}, + {int(problem_.ldc)} + ); + + gemm_workspace_.Reference->copy_from_device(gemm_workspace_.C->data()); + } + + // + // Initialize the CUTLASS operation + // + + Status status = Status::kSuccess; + + if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { + + if (options.execution_mode != ExecutionMode::kDryRun) { + + uint64_t workspace_size = operation->get_host_workspace_size(&gemm_workspace_.configuration); + gemm_workspace_.host_workspace.resize(workspace_size, 0); + + workspace_size = operation->get_device_workspace_size(&gemm_workspace_.configuration); + gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); + + status = operation->initialize( + &gemm_workspace_.configuration, + gemm_workspace_.host_workspace.data(), + gemm_workspace_.device_workspace.data()); + } + + // + // If CUTLASS is enabled, generate a result for it + // + + results_.push_back(model_result_); + results_.back().provider = library::Provider::kCUTLASS; + results_.back().op_kind = library::OperationKind::kSparseGemm; + results_.back().disposition = Disposition::kNotRun; + + for(auto &verification_provider : options.verification.providers) { + results_.back().verification_map[verification_provider] = Disposition::kNotRun; + } + } + + return status; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Verifies CUTLASS against references +bool SparseGemmOperationProfiler::verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) { + + if (!options.profiling.provider_enabled(library::Provider::kCUTLASS)) { + return true; + } + + if (options.execution_mode == ExecutionMode::kDryRun) { + return true; + } + + // Initialize structure containing GEMM arguments + gemm_workspace_.arguments.A = gemm_workspace_.A->data(); + gemm_workspace_.arguments.B = gemm_workspace_.B->data(); + gemm_workspace_.arguments.C = gemm_workspace_.C->data(); + gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); + gemm_workspace_.arguments.E = gemm_workspace_.E->data(); + gemm_workspace_.arguments.alpha = problem_.alpha.data(); + gemm_workspace_.arguments.beta = problem_.beta.data(); + gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; + + // + // Run the CUTLASS operation + // + + results_.back().status = operation->run( + &gemm_workspace_.arguments, + gemm_workspace_.host_workspace.data(), + gemm_workspace_.device_workspace.data()); + + if (results_.back().status != Status::kSuccess) { + results_.back().disposition = Disposition::kFailed; + return false; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + results_.back().disposition = Disposition::kFailed; + return false; + } + + // CUTLASS op ran the but not yet verified against any verification provider + results_.back().disposition = Disposition::kNotVerified; + + // + // Run verification providers + // + + if 
(options.verification.enabled) { + + // Update disposition to worst case verification outcome among all + // verification providers which are supported + bool is_any_verification_run_passed = false; + + for(auto &m : results_.back().verification_map) { + if(m.second == Disposition::kFailed || m.second == Disposition::kIncorrect) { + results_.back().disposition = m.second; + return true; + } + if(!is_any_verification_run_passed && m.second == Disposition::kPassed) { + is_any_verification_run_passed = true; + } + } + + if(is_any_verification_run_passed) { + results_.back().disposition = Disposition::kPassed; + } + } + + // Return true means continue profiling + return true; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Measures performance results +bool SparseGemmOperationProfiler::profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) { + + if (options.profiling.provider_enabled(library::Provider::kCUTLASS)) { + + // Initialize structure containing GEMM arguments + gemm_workspace_.arguments.A = gemm_workspace_.A->data(); + gemm_workspace_.arguments.B = gemm_workspace_.B->data(); + gemm_workspace_.arguments.C = gemm_workspace_.C->data(); + gemm_workspace_.arguments.D = gemm_workspace_.Computed->data(); + gemm_workspace_.arguments.E = gemm_workspace_.E->data(); + gemm_workspace_.arguments.alpha = problem_.alpha.data(); + gemm_workspace_.arguments.beta = problem_.beta.data(); + gemm_workspace_.arguments.pointer_mode = library::ScalarPointerMode::kHost; + + results_.back().status = profile_cutlass_( + results_.back().runtime, + options, + operation, + &gemm_workspace_.arguments, + gemm_workspace_.host_workspace.data(), + gemm_workspace_.device_workspace.data() + ); + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/profiler/src/sparse_gemm_operation_profiler.h b/tools/profiler/src/sparse_gemm_operation_profiler.h new file mode 100644 index 0000000000..37905d3b88 --- /dev/null +++ b/tools/profiler/src/sparse_gemm_operation_profiler.h @@ -0,0 +1,208 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. 
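As a numeric sanity check of the bytes/flops model used in SparseGemmOperationProfiler::initialize_result_ earlier in this file, the sketch below evaluates the same expressions for one hypothetical problem. The element widths, the 16-bit metadata type, and the 2:4 sparsity factor are assumptions chosen for illustration, not values read from a kernel description.

// Illustrative only: evaluates the sparse GEMM bytes/flops model shown above
// for a hypothetical problem (m = n = k = 4096, f16 operands, 16-bit metadata,
// 2:4 structured sparsity, beta == 0 so C is not read).
#include <cstdint>
#include <cstdio>

int main() {
  int64_t const m = 4096, n = 4096, k = 4096;
  int64_t const sparse = 2;                                          // 2:4 sparsity keeps k/2 values of A
  int64_t const bits_A = 16, bits_B = 16, bits_C = 16, bits_E = 16;  // assumed element widths
  int64_t const elements_per_128b = 128 / bits_A;                    // = 8 for f16

  // Input bytes read plus output bytes written (beta == 0 case)
  int64_t bytes =
      (bits_A * m / 8) * k / sparse +                      // compressed A
      (bits_B * n / 8) * k +                               // dense B
      (bits_C * m / 8) * n +                               // D written back
      (bits_E * m / 8) * k / sparse / elements_per_128b;   // metadata E

  int64_t flops = 2 * (m * n * k + m * n);                 // MMA plus epilogue

  std::printf("bytes: %lld (%.1f MiB)\n", (long long)bytes, double(bytes) / (1 << 20));
  std::printf("flops: %lld (%.1f GFLOP)\n", (long long)flops, double(flops) * 1e-9);
  return 0;
}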
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "gemm_operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class SparseGemmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct SparseGemmProblem { + int64_t m; + int64_t n; + int64_t k; + int64_t lda; + int64_t ldb; + int64_t ldc; + int64_t lde; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + static int const sparse = 2; + // every 128b ElementA uses one elementE + int elements_per_128b; + + // + // Methods + // + + SparseGemmProblem(): + m(16), n(16), k(16), lda(0), ldb(0), ldc(0), lde(0), split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::SparseGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::SparseGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct SparseGemmWorkspace { + + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *C; + DeviceAllocation *E; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::SparseGemmConfiguration configuration; + library::SparseGemmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + SparseGemmWorkspace(): + A(nullptr), B(nullptr), C(nullptr), E(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + // GEMM problem + SparseGemmProblem problem_; + + /// Device memory allocations + SparseGemmWorkspace gemm_workspace_; + + +public: + // + // Methods + // + + /// Ctor + SparseGemmOperationProfiler(Options const &options); + + /// Destructor + virtual ~SparseGemmOperationProfiler(); + + /// 
Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::SparseGemmDescription const &operation_desc, + ProblemSpace const &problem_space); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/util/include/cutlass/util/exceptions.h b/tools/util/include/cutlass/util/exceptions.h index b6cf2fcd8e..519205f6d2 100644 --- a/tools/util/include/cutlass/util/exceptions.h +++ b/tools/util/include/cutlass/util/exceptions.h @@ -49,14 +49,9 @@ class cuda_exception : public std::exception { cudaError_t err; }; -/// Writes a cudaError_t to an output stream -inline std::ostream& operator<<(std::ostream& out, cudaError_t result) { - return out << cudaGetErrorString(result); -} - /// Writes a cuda_exception instance to an output stream inline std::ostream& operator<<(std::ostream& out, cuda_exception const& e) { - return out << e.what() << ": " << e.cudaError(); + return out << e.what() << ": " << cudaGetErrorString(e.cudaError()); } } // namespace cutlass diff --git a/tools/util/include/cutlass/util/host_reorder.h b/tools/util/include/cutlass/util/host_reorder.h index d46d45946f..1d12add3ef 100644 --- a/tools/util/include/cutlass/util/host_reorder.h +++ b/tools/util/include/cutlass/util/host_reorder.h @@ -37,6 +37,8 @@ namespace cutlass { +/// This is needed for the interleaved integer tensor core kernels. The purpose +/// is to use skip the shared memory part in the epilogue. template void reorder_column(TensorRef dest, TensorRef src, @@ -60,4 +62,32 @@ void reorder_column(TensorRef dest, } } +/// This is needed for the sparse tensor core kernels. The purpose +/// is to use ldmatrix to load from shared memory to the register file. +template +void reorder_meta(TensorRef dest, + TensorRef src, + cutlass::gemm::GemmCoord problem_size) { + for (int m = 0; m < problem_size.m(); m++) { + for (int k = 0; k < problem_size.k(); k++) { + // First reorder the rows. + int group = (sizeof(Element) == 2) ? 
32 : 16; + int interweave = (sizeof(Element) == 2) ? 4 : 2; + + int dest_row = m / group * group + (m % 8) * interweave + (m % group) / 8; + int dest_col = k; + + // Next swizzle the 2x2 blocks from Z to N. + if (((dest_row % 2) == 0) && ((dest_col % 2) == 1)) { + ++dest_row; + --dest_col; + } else if (((dest_row % 2) == 1) && ((dest_col % 2) == 0)) { + --dest_row; + ++dest_col; + } + + dest.at({dest_row, dest_col}) = src.at({m, k}); + } + } +} } // namespace cutlass diff --git a/tools/util/include/cutlass/util/host_tensor.h b/tools/util/include/cutlass/util/host_tensor.h index c734a5f5eb..465d74a93b 100644 --- a/tools/util/include/cutlass/util/host_tensor.h +++ b/tools/util/include/cutlass/util/host_tensor.h @@ -39,7 +39,6 @@ #include #include "cutlass/cutlass.h" -#include "cutlass/matrix_traits.h" #include "cutlass/tensor_ref.h" #include "cutlass/tensor_view.h" diff --git a/tools/util/include/cutlass/util/host_tensor_planar_complex.h b/tools/util/include/cutlass/util/host_tensor_planar_complex.h index 3a31e29a43..6bdc8fe47b 100644 --- a/tools/util/include/cutlass/util/host_tensor_planar_complex.h +++ b/tools/util/include/cutlass/util/host_tensor_planar_complex.h @@ -39,7 +39,6 @@ #include #include "cutlass/cutlass.h" -#include "cutlass/matrix_traits.h" #include "cutlass/tensor_ref_planar_complex.h" #include "cutlass/tensor_view_planar_complex.h" diff --git a/tools/util/include/cutlass/util/host_uncompress.h b/tools/util/include/cutlass/util/host_uncompress.h new file mode 100644 index 0000000000..8b630030e5 --- /dev/null +++ b/tools/util/include/cutlass/util/host_uncompress.h @@ -0,0 +1,117 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief uncompress sparse matrix from the host side +*/ +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/tensor_view.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { + +template +void uncompress(TensorRef uncompressed_tensor_a, + TensorRef tensor_a, + TensorRef tensor_e, int row, int col) { + // How many uncompressed data we can get with ElementE meta data + int DecompressedElementsPerElementE = + 256 / cutlass::sizeof_bits::value; + + // Process 4bit meta data a time + int step; + + // 1:2 or 2:4 or 4:8 + int a, b; + + if (cutlass::sizeof_bits::value == 4) { + step = 8; + a = 4; + b = 8; + } else if (cutlass::sizeof_bits::value == 8) { + step = 4; + a = 2; + b = 4; + } else if (cutlass::sizeof_bits::value == 16) { + step = 4; + a = 2; + b = 4; + } else if (cutlass::sizeof_bits::value == 32) { + step = 2; + a = 1; + b = 2; + } + + int ElementsPerE = (cutlass::sizeof_bits::value == 4) ? 2 : 1; + + for (int r = 0; r < row; ++r) { + for (int c = 0; c < (col / DecompressedElementsPerElementE); ++c) { + + ElementE meta = tensor_e.at(MatrixCoord(r, c)); + + for (int i = 0; i < DecompressedElementsPerElementE; i += step) { + int e = (meta >> (i / step * 4)) & 0xf; + int idx0 = e & 0x3; + int idx1 = e >> 2; + + if (a == 1) idx0 = idx0 / 2; + + for (int ii = 0; ii < step; ii += ElementsPerE) { + int real_col = + c * DecompressedElementsPerElementE + i + ii; + int compressed_col = (real_col / b) * a; + + if (ii == (idx0 * ElementsPerE)) { + uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = + tensor_a.at(MatrixCoord(r, compressed_col)); + if (ElementsPerE == 2) + uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = + tensor_a.at(MatrixCoord(r, compressed_col + 1)); + } else if ((ii == (idx1 * ElementsPerE)) && (a != 1)) { + uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = + tensor_a.at(MatrixCoord(r, compressed_col + ElementsPerE)); + if (ElementsPerE == 2) + uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = + tensor_a.at( + MatrixCoord(r, compressed_col + ElementsPerE + 1)); + } else { + uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = + ElementA(0); + if (ElementsPerE == 2) + uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = + ElementA(0); + } + } + } + } + } +} +} // namespace cutlass + diff --git a/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h b/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h new file mode 100644 index 0000000000..db00e712ed --- /dev/null +++ b/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h @@ -0,0 +1,88 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
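A hypothetical host-side use of cutlass::uncompress() from host_uncompress.h above: expand a 2:4-compressed f16 operand A (m x k/2) and its metadata E back into a dense m x k tensor, for example before feeding it to a host or device reference GEMM. The tensor names, extents, and the 16-bit metadata element are illustrative assumptions rather than code lifted from the sparse testbeds.

// Hypothetical usage of cutlass::uncompress(): reconstruct a dense m x k
// operand from its 2:4-compressed values and metadata.
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/host_uncompress.h"

int main() {
  int const m = 128, k = 64;
  int const sparse = 2;                 // 2:4 structured sparsity keeps k/2 values per row
  int const elements_per_E = 256 / 16;  // one 16-bit metadata word covers 16 logical f16 elements

  cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor>
      tensor_a(cutlass::MatrixCoord(m, k / sparse));           // compressed A
  cutlass::HostTensor<uint16_t, cutlass::layout::RowMajor>
      tensor_e(cutlass::MatrixCoord(m, k / elements_per_E));   // metadata E
  cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor>
      tensor_a_uncompressed(cutlass::MatrixCoord(m, k));       // dense output

  // ... fill tensor_a with the kept values and tensor_e with valid 2:4 metadata ...

  cutlass::uncompress(
      tensor_a_uncompressed.host_ref(),   // dense m x k output
      tensor_a.host_ref(),                // compressed m x (k/2) input
      tensor_e.host_ref(),                // metadata: one 4-bit nibble per group of four elements
      m, k);

  return 0;
}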
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/coord.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct LinearToCoordinateHelper { + + CUTLASS_HOST_DEVICE + void operator()(Coord &coord, int64_t idx, Coord const &extent) const { + + int64_t prod = 1; + + CUTLASS_PRAGMA_UNROLL + for (int i = Rank - Index; i < Rank; ++i) { + prod *= int64_t(extent[i]); + } + + coord[Rank - Index - 1] = int(idx / prod); + + int64_t residual = idx % prod; + LinearToCoordinateHelper()(coord, residual, extent); + } +}; + +template +struct LinearToCoordinateHelper { + + CUTLASS_HOST_DEVICE + void operator()(Coord &coord, int64_t idx, Coord const &extent) const { + coord[Rank - 1] = int(idx); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct LinearToCoordinate { + + CUTLASS_HOST_DEVICE + void operator()(Coord &coord, int64_t idx, Coord const &extent) const { + LinearToCoordinateHelper()(coord, idx, extent); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/util/include/cutlass/util/reference/device/gemm.h b/tools/util/include/cutlass/util/reference/device/gemm.h index 5aef19ff23..3e4bfb31b6 100644 --- a/tools/util/include/cutlass/util/reference/device/gemm.h +++ b/tools/util/include/cutlass/util/reference/device/gemm.h @@ -34,7 +34,6 @@ #include "cutlass/functional.h" #include "cutlass/numeric_conversion.h" -#include "cutlass/matrix_traits.h" #include "cutlass/tensor_view.h" #include "cutlass/gemm/gemm.h" diff --git a/tools/util/include/cutlass/util/reference/device/gemm_complex.h b/tools/util/include/cutlass/util/reference/device/gemm_complex.h new file mode 100644 index 0000000000..7c736603bb --- /dev/null +++ b/tools/util/include/cutlass/util/reference/device/gemm_complex.h @@ -0,0 +1,295 @@ +/*************************************************************************************************** 
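A minimal sketch of how cutlass::reference::detail::LinearToCoordinate above decomposes a linear index, assuming the stripped template parameters are <Rank, Index> for the helper and <Rank> for the public struct, as the member code implies. The extent and index below are arbitrary example values.

// Decompose a linear index into a rank-3 coordinate, treating the last
// dimension as fastest-varying (the order the helper above unrolls).
#include <cstdio>
#include "cutlass/coord.h"
#include "cutlass/util/reference/detail/linear_to_coordinate.h"

int main() {
  cutlass::Coord<3> extent = cutlass::make_Coord(2, 3, 4);   // 24 elements total
  cutlass::Coord<3> coord;

  cutlass::reference::detail::LinearToCoordinate<3>()(coord, /*idx=*/17, extent);

  // 17 = 1 * (3*4) + 1 * 4 + 1, so this prints (1, 1, 1)
  std::printf("(%d, %d, %d)\n", coord[0], coord[1], coord[2]);
  return 0;
}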
+ * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in device-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kMblock = 4, + int kNblock = 4 +> +__global__ void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + int batch_idx = blockIdx.z; + + tensor_a.add_pointer_offset(batch_idx * batch_stride_A); + tensor_b.add_pointer_offset(batch_idx * batch_stride_B); + tensor_c.add_pointer_offset(batch_idx * batch_stride_C); + tensor_d.add_pointer_offset(batch_idx * batch_stride_D); + + for (; batch_idx < batch_count; batch_idx += gridDim.z) { + + // Compute matrix product using blocks + ComputeType accum[kMblock][kNblock]; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b = tensor_b.at(MatrixCoord(k_block, col)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_kj = ComputeType(b); + + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + + if (transform_b == ComplexTransform::kConjugate) { + b_kj = conj(b_kj); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + + tensor_a.add_pointer_offset(batch_stride_A * gridDim.z); + tensor_b.add_pointer_offset(batch_stride_B * gridDim.z); + tensor_c.add_pointer_offset(batch_stride_C * gridDim.z); + tensor_d.add_pointer_offset(batch_stride_D * gridDim.z); + + } // for (batch_idx) +} + +} // namespace kernel + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. 
Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const kMblock = 4; + int const kNblock = 4; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + batch_count % std::numeric_limits::max() + ); + + kernel::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ScalarType, + ComputeType, + ConvertOp, + InnerProductOp, + kMblock, + kNblock + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum, + batch_count, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d) { + + GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h b/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h index b3003409bb..b9bdbfa026 100644 --- a/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h +++ b/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h @@ -36,7 +36,6 @@ #include "cutlass/numeric_conversion.h" #include "cutlass/tensor_ref_planar_complex.h" -#include "cutlass/matrix_traits.h" #include "cutlass/tensor_view.h" #include "cutlass/gemm/gemm.h" diff --git a/tools/util/include/cutlass/util/reference/device/kernel/gemm.h b/tools/util/include/cutlass/util/reference/device/kernel/gemm.h index 4c8e361ecb..3b9688d17a 100644 --- a/tools/util/include/cutlass/util/reference/device/kernel/gemm.h +++ b/tools/util/include/cutlass/util/reference/device/kernel/gemm.h @@ -29,7 +29,6 @@ #pragma once #include "cutlass/coord.h" -#include "cutlass/matrix_traits.h" #include "cutlass/tensor_view.h" #include "cutlass/gemm/gemm.h" diff --git a/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h b/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h index 64cb37bea2..8d813ea243 100644 --- a/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h +++ b/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h @@ -27,6 +27,8 @@ #include "cutlass/cutlass.h" #include "cutlass/coord.h" +#include "cutlass/subbyte_reference.h" +#include "cutlass/fast_math.h" namespace cutlass { namespace reference { @@ -138,7 +140,7 @@ __global__ void BlockForEach( size_t index = threadIdx.x + blockIdx.x * blockDim.x; for (; index < capacity; index += blockDim.x * gridDim.x) { - ptr[index] = func(); + ReferenceFactory::get(ptr, index) = func(); } } diff --git a/tools/util/include/cutlass/util/reference/device/tensor_compare.h b/tools/util/include/cutlass/util/reference/device/tensor_compare.h index 3323bed51e..eb61754e47 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_compare.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_compare.h @@ -56,8 +56,13 @@ __global__ void BlockCompareEqual( size_t idx = threadIdx.x + blockDim.x * blockIdx.x; for (; idx < capacity; idx += gridDim.x * blockDim.x) { - if (ptr_A[idx] != ptr_B[idx]) { + + Element a = cutlass::ReferenceFactory::get(ptr_A, idx); + Element b = cutlass::ReferenceFactory::get(ptr_B, idx); + + if (a != b) { *equal = 0; + return; } } @@ -76,8 +81,8 @@ __global__ void BlockCompareRelativelyEqual( for (; idx < capacity; idx += gridDim.x * blockDim.x) { - Element a = ptr_A[idx]; - Element b = ptr_B[idx]; + Element a = cutlass::ReferenceFactory::get(ptr_A, idx); + Element b = cutlass::ReferenceFactory::get(ptr_B, idx); if (!relatively_equal(a, b, epsilon, nonzero_floor)) { *equal = 0; 
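// Illustrative usage sketch (editorial, not part of the patch): one plausible way to invoke the
// new device-side reference cutlass::reference::device::GemmComplex introduced above from a .cu
// translation unit. The HostTensor-backed storage, column-major layouts, problem size, and the
// conjugation of B are illustrative assumptions only. As the doc comment notes, passing the
// initial accumulator (here Element(0)) as the final argument avoids naming ComputeType
// explicitly.

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/gemm_complex.h"

void reference_gemm_complex_sketch() {

  using Element = cutlass::complex<float>;
  using Layout  = cutlass::layout::ColumnMajor;

  cutlass::gemm::GemmCoord problem(128, 128, 64);

  // Allocate host/device tensors for A (m-by-k), B (k-by-n), C and D (m-by-n).
  cutlass::HostTensor<Element, Layout> A(problem.mk());
  cutlass::HostTensor<Element, Layout> B(problem.kn());
  cutlass::HostTensor<Element, Layout> C(problem.mn());
  cutlass::HostTensor<Element, Layout> D(problem.mn());

  // ... initialize A, B, C on the host (e.g. with the tensor_fill.h helpers) ...
  A.sync_device();
  B.sync_device();
  C.sync_device();

  // Launch the reference kernel: D = alpha * A * conj(B) + beta * C
  cutlass::reference::device::GemmComplex(
    problem,
    Element(1),                                        // alpha
    A.device_ref(), cutlass::ComplexTransform::kNone,
    B.device_ref(), cutlass::ComplexTransform::kConjugate,
    Element(0),                                        // beta
    C.device_ref(),
    D.device_ref(),
    Element(0));                                       // initial accumulator

  cudaDeviceSynchronize();
  D.sync_host();                                       // copy the result back for comparison
}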
diff --git a/tools/util/include/cutlass/util/reference/device/tensor_fill.h b/tools/util/include/cutlass/util/reference/device/tensor_fill.h index 962ded0940..ff2e5f3666 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_fill.h @@ -106,6 +106,8 @@ struct RandomGaussianFunc { FloatType mean; FloatType stddev; int int_scale; + FloatType float_scale_up; + FloatType float_scale_down; // // Methods @@ -123,6 +125,9 @@ struct RandomGaussianFunc { stddev(static_cast(stddev_)), int_scale(int_scale_) { + float_scale_up = FloatType(IntType(1) << int_scale); + float_scale_up += FloatType(0.5) * float_scale_up; + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); } }; @@ -158,8 +163,8 @@ struct RandomGaussianFunc { Element result; if (params.int_scale >= 0) { - rnd = FloatType(IntType(rnd * FloatType(IntType(1) << params.int_scale))); - result = Element(rnd / FloatType(IntType(1) << params.int_scale)); + rnd = FloatType(IntType(rnd * params.float_scale_up)); + result = Element(rnd * params.float_scale_down); } else { result = Element(rnd); @@ -188,6 +193,8 @@ struct RandomGaussianFunc> { FloatType mean; FloatType stddev; int int_scale; + FloatType float_scale_up; + FloatType float_scale_down; // // Methods @@ -205,6 +212,9 @@ struct RandomGaussianFunc> { stddev(static_cast(stddev_)), int_scale(int_scale_) { + float_scale_up = FloatType(IntType(1) << int_scale); + float_scale_up += FloatType(0.5) * float_scale_up; + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); } }; @@ -242,12 +252,12 @@ struct RandomGaussianFunc> { Element result; if (params.int_scale >= 0) { - rnd_r = FloatType(IntType(rnd_r * FloatType(IntType(1) << params.int_scale))); - rnd_i = FloatType(IntType(rnd_i * FloatType(IntType(1) << params.int_scale))); + rnd_r = FloatType(IntType(rnd_r * params.float_scale_up)); + rnd_i = FloatType(IntType(rnd_i * params.float_scale_down)); result = { - Real(rnd_r / FloatType(IntType(1) << params.int_scale)), - Real(rnd_i / FloatType(IntType(1) << params.int_scale)) + Real(rnd_r * params.float_scale_down), + Real(rnd_i * params.float_scale_down) }; } else { @@ -378,7 +388,7 @@ void BlockFillRandomGaussian( namespace detail { /// Computes a random Gaussian distribution -template ///< Layout function +template ///< Element type struct RandomUniformFunc { using FloatType = typename std::conditional< @@ -400,8 +410,10 @@ struct RandomUniformFunc { uint64_t seed; FloatType range; - FloatType min; + FloatType max; int int_scale; + FloatType float_scale_up; + FloatType float_scale_down; /// Default ctor CUTLASS_HOST_DEVICE @@ -414,15 +426,18 @@ struct RandomUniformFunc { /// Construction of Gaussian RNG functor. 
Params( uint64_t seed_ = 0, - Element max = 1, - Element min_ = 0, + Element max_ = 1, + Element min = 0, int int_scale_ = -1 ): seed(seed_), - range(static_cast(max - min_)), - min(static_cast(min_)), + range(static_cast(max_ - min)), + max(static_cast(max_)), int_scale(int_scale_) { + float_scale_up = FloatType(IntType(1) << int_scale); + float_scale_up += FloatType(0.5) * float_scale_up; + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); } }; @@ -454,15 +469,15 @@ struct RandomUniformFunc { Element operator()() { FloatType rnd = random_uniform_float(&rng_state); - rnd = params.min + params.range * rnd; + rnd = params.max - params.range * rnd; // Random values are cast to integer after scaling by a power of two to facilitate error // testing Element result; if (params.int_scale >= 0) { - rnd = FloatType(IntType(rnd * FloatType(IntType(1) << params.int_scale))); - result = Element(rnd / FloatType(IntType(1) << params.int_scale)); + rnd = FloatType(IntType(rnd * params.float_scale_up)); + result = Element(rnd * params.float_scale_down); } else { result = Element(rnd); @@ -473,7 +488,7 @@ struct RandomUniformFunc { }; /// Computes a random Gaussian distribution -template ///< Layout function +template struct RandomUniformFunc> { using Element = complex; @@ -499,6 +514,8 @@ struct RandomUniformFunc> { FloatType range; FloatType min; int int_scale; + FloatType float_scale_up; + FloatType float_scale_down; /// Default ctor CUTLASS_HOST_DEVICE @@ -520,6 +537,9 @@ struct RandomUniformFunc> { min(static_cast(min_)), int_scale(int_scale_) { + float_scale_up = FloatType(IntType(1) << int_scale); + float_scale_up += FloatType(0.5) * float_scale_up; + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); } }; @@ -561,12 +581,12 @@ struct RandomUniformFunc> { Element result; if (params.int_scale >= 0) { - rnd_r = FloatType(IntType(rnd_r * FloatType(IntType(1) << params.int_scale))); - rnd_i = FloatType(IntType(rnd_i * FloatType(IntType(1) << params.int_scale))); + rnd_r = FloatType(IntType(rnd_r * params.float_scale_up)); + rnd_i = FloatType(IntType(rnd_i * params.float_scale_up)); result = { - Real(rnd_r / FloatType(IntType(1) << params.int_scale)), - Real(rnd_i / FloatType(IntType(1) << params.int_scale)) + Real(rnd_r * params.float_scale_down), + Real(rnd_i * params.float_scale_down) }; } else { @@ -670,7 +690,7 @@ void TensorFillRandomUniform( typename RandomFunc::Params random(seed, max, min, bits); TensorForEach( - view.size(), + view.extent(), Params(view, random) ); } @@ -690,6 +710,7 @@ void BlockFillRandomUniform( /// data. using RandomFunc = detail::RandomUniformFunc; + typename RandomFunc::Params params(seed, max, min, bits); BlockForEach(ptr, capacity, params); @@ -700,7 +721,214 @@ void BlockFillRandomUniform( namespace detail { +/// Computes a random sparse meta +template ///< Element type +struct RandomSparseMetaFunc { + + using FloatType = float; + + using IntType = int32_t; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType range; + int MetaSizeInBits; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. 
+ Params( + uint64_t seed_ = 0, + int MetaSizeInBits_ = 2 + ): + seed(seed_), + MetaSizeInBits(MetaSizeInBits_) { + if (MetaSizeInBits_ == 2) { + range = 6; + } else if (MetaSizeInBits_ == 4) { + range = 2; + } + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomSparseMetaFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}; + Element TwoToOneMeta[2] = {0x4, 0xe}; + + Element *MetaArray = + (params.MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta; + + Element result = 0x0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { + FloatType rnd = random_uniform_float(&rng_state); + rnd = params.range * rnd; + Element meta = MetaArray[(int)rnd]; + + result = (Element)(result | ((Element)(meta << (i * 4)))); + } + + return result; + } +}; + /// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomSparseMetaFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + using RandomFunc = RandomSparseMetaFunc; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + typename RandomFunc::Params random; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + typename RandomFunc::Params random_ = RandomFunc::Params() + ): + view(view_), random(random_) { + + } + }; + + // + // Data members + // + + Params params; + RandomFunc random; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillRandomSparseMetaFunc(Params const ¶ms): params(params), random(params.random) { + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + params.view.at(coord) = random(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomSparseMeta( + TensorView view, ///< destination tensor + uint64_t seed, ///< seed for RNG + int MetaSizeInBits = 2) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + using RandomFunc = detail::RandomSparseMetaFunc; + using Func = detail::TensorFillRandomUniformFunc; + using Params = typename Func::Params; + + typename RandomFunc::Params random(seed, MetaSizeInBits); + + TensorForEach( + view.extent(), + Params(view, random) + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. 
+template +void BlockFillRandomSparseMeta( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + int MetaSizeInBits = 2) { ///< meta data size + + using RandomFunc = detail::RandomSparseMetaFunc; + + typename RandomFunc::Params params(seed, MetaSizeInBits); + + BlockForEach(ptr, capacity, params); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Functor to fill a tensor with zeros off the diagonal and a uniform value on the diagonal. template < typename Element, ///< Element type typename Layout> ///< Layout function @@ -734,7 +962,6 @@ struct TensorFillDiagonalFunc { // Methods // - /// Construction of Gaussian RNG functor. Params( TensorView view_ = TensorView(), Element diag_ = Element(1), @@ -762,7 +989,7 @@ struct TensorFillDiagonalFunc { } - /// Compute random value and update RNG state + /// Updates the tensor CUTLASS_DEVICE void operator()(TensorCoord const &coord) { @@ -797,7 +1024,7 @@ void TensorFillDiagonal( typedef typename Func::Params Params; TensorForEach( - view.size(), + view.extent(), Params(view, diag, other) ); } @@ -928,7 +1155,7 @@ void TensorUpdateDiagonal( typedef typename Func::Params Params; TensorForEach( - view.size(), + view.extent(), Params(view, diag) ); } @@ -1034,7 +1261,7 @@ void TensorUpdateOffDiagonal( typedef typename Func::Params Params; TensorForEach( - view.size(), + view.extent(), Params(view, other) ); } @@ -1137,7 +1364,7 @@ void TensorFillLinear( using Params = typename Func::Params; TensorForEach( - view.size(), + view.extent(), Params(view, v, s) ); } @@ -1290,7 +1517,7 @@ void TensorCopyDiagonalIn( using Params = typename Func::Params; TensorForEach( - view.size(), + view.extent(), Params(view, ptr) ); } @@ -1394,7 +1621,7 @@ void TensorCopyDiagonalOut( using Params = typename Func::Params; TensorForEach( - view.size(), + view.extent(), Params(view, ptr) ); } diff --git a/tools/util/include/cutlass/util/reference/device/tensor_foreach.h b/tools/util/include/cutlass/util/reference/device/tensor_foreach.h index d03080b2a0..54621006e1 100644 --- a/tools/util/include/cutlass/util/reference/device/tensor_foreach.h +++ b/tools/util/include/cutlass/util/reference/device/tensor_foreach.h @@ -95,7 +95,7 @@ struct BlockForEach { BlockForEach( Element *ptr, size_t capacity, - typename Func::Params params = typename Func::Params(), + typename Func::Params params = typename Func::Params(), int grid_size = 0, int block_size = 0) { diff --git a/tools/util/include/cutlass/util/reference/device/tensor_reduce.h b/tools/util/include/cutlass/util/reference/device/tensor_reduce.h new file mode 100644 index 0000000000..a268c92526 --- /dev/null +++ b/tools/util/include/cutlass/util/reference/device/tensor_reduce.h @@ -0,0 +1,505 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/detail/linear_to_coordinate.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp, + int kBlockSize = 128 +> +__global__ void TensorTransformReducePartial( + TensorView view, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + + int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + int64_t size = view.size(); + + __shared__ ComputeType scratchpad[kBlockSize]; + + for (; idx < size; idx += blockDim.x * gridDim.x) { + + // Map linear thread ID onto tensor coordinate + typename Layout::TensorCoord coord; + + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view.extent()); + + if (view.contains(coord)) { + + // Fetch element + Element x = view.at(coord); + + // Transform + identity = reduce(identity, transform(x)); + } + } + + scratchpad[threadIdx.x] = identity; + + __syncthreads(); + + // One thread performs the final reduction and stores out. This could be enhanced via + // a tree reduction and pipelining. 
+ if (threadIdx.x == 0) { + + for (int i = 1; i < kBlockSize; ++i) { + identity = reduce(identity, scratchpad[i]); + } + + workspace[blockIdx.x] = identity; + } +} + +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp, + int kBlockSize = 128 +> +__global__ void TensorTransformReducePartial( + TensorView view_A, /// View of the tensor to reduce over + TensorView view_B, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + + int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + int64_t size = view_A.size(); + + __shared__ ComputeType scratchpad[kBlockSize]; + + for (; idx < size; idx += blockDim.x * gridDim.x) { + + // Map linear thread ID onto tensor coordinate + typename Layout::TensorCoord coord; + + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view_A.extent()); + + if (view_A.contains(coord)) { + + // Fetch element + Element a = view_A.at(coord); + Element b = view_B.at(coord); + + // Transform + identity = reduce(identity, transform(a, b)); + } + } + + scratchpad[threadIdx.x] = identity; + + __syncthreads(); + + // One thread performs the final reduction and stores out. This could be enhanced via + // a tree reduction and pipelining. + if (threadIdx.x == 0) { + + for (int i = 1; i < kBlockSize; ++i) { + identity = reduce(identity, scratchpad[i]); + } + + workspace[blockIdx.x] = identity; + } +} + + +template < + typename ComputeType, + typename ReduceOp, + int kBlockSize = 32 +> +__global__ void TensorTransformReduceFinalize( + ComputeType *workspace, + ComputeType identity, + int workspace_size, + ReduceOp reduce) { + + __shared__ ComputeType scratchpad[kBlockSize]; + + for (int idx = threadIdx.x; idx < workspace_size; idx += kBlockSize) { + identity = reduce(identity, workspace[idx]); + } + + scratchpad[threadIdx.x] = identity; + + __syncthreads(); + + if (threadIdx.x == 0) { + + for (int i = 1; i < kBlockSize; ++i) { + identity = reduce(identity, scratchpad[i]); + } + + workspace[0] = identity; + } +} + +} // namespace kernel + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Transform-reduce operation over the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + int workspace_size, /// Number of elements in workspace + cudaStream_t stream = nullptr, /// CUDA stream to launch into + bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned. 
+) { + + int const kBlockSize = 128; + + dim3 block(kBlockSize, 1); + dim3 grid(workspace_size, 1); + + kernel::TensorTransformReducePartial< + Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize + ><<< grid, block, 0, stream >>>( + view, identity, reduce, transform, workspace + ); + + int const kFinalizeBlockSize = 32; + + kernel::TensorTransformReduceFinalize< + ComputeType, ReduceOp, kFinalizeBlockSize + ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>( + workspace, identity, workspace_size, reduce + ); + + if (copy_out) { + cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost); + if (result != cudaSuccess) { + throw std::runtime_error("cudaMemcpy() failed"); + } + } + + return identity; +} + +/// Transform-reduce operation over the elements of two tensors, zipped together +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view_A, /// View of the tensor to reduce over + TensorView view_B, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + int workspace_size, /// Number of elements in workspace + cudaStream_t stream = nullptr, /// CUDA stream to launch into + bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned. +) { + + if (view_A.extent() != view_B.extent()) { + throw std::runtime_error("Extents must be equal."); + } + + int const kBlockSize = 128; + + dim3 block(kBlockSize, 1); + dim3 grid(workspace_size, 1); + + kernel::TensorTransformReducePartial< + Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize + ><<< grid, block, 0, stream >>>( + view_A, view_B, identity, reduce, transform, workspace + ); + + int const kFinalizeBlockSize = 32; + + kernel::TensorTransformReduceFinalize< + ComputeType, ReduceOp, kFinalizeBlockSize + ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>( + workspace, identity, workspace_size, reduce + ); + + if (copy_out) { + cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost); + if (result != cudaSuccess) { + throw std::runtime_error("cudaMemcpy() failed"); + } + } + + return identity; +} + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view, + ComputeType identity, + ReduceOp reduce, + TransformOp transform, + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + // Optionally query for the SM count to size the workspace. 
+ if (!workspace_size) { + + int device_idx = 0; + cudaDeviceProp prop; + + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() failed"); + } + + result = cudaGetDeviceProperties(&prop, device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProp() failed"); + } + + workspace_size = int(prop.multiProcessorCount); + } + + DeviceAllocation workspace(workspace_size); + + ComputeType output = TensorTransformReduce( + view, + identity, + reduce, + transform, + workspace.get(), + workspace_size, + stream, + true); + + return output; +} + + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view_A, + TensorView view_B, + ComputeType identity, + ReduceOp reduce, + TransformOp transform, + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + // Optionally query for the SM count to size the workspace. + if (!workspace_size) { + + int device_idx = 0; + cudaDeviceProp prop; + + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() failed"); + } + + result = cudaGetDeviceProperties(&prop, device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProp() failed"); + } + + workspace_size = int(prop.multiProcessorCount); + } + + DeviceAllocation workspace(workspace_size); + + ComputeType output = TensorTransformReduce( + view_A, + view_B, + identity, + reduce, + transform, + workspace.get(), + workspace_size, + stream, + true); + + return output; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to compute the sum of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSum( + TensorView view, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + plus reduce; + NumericConverter transform; + + return TensorTransformReduce( + view, identity, reduce, transform, stream, workspace_size); +} + +/// Helper to compute the sum of the squares of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSumSq( + TensorView view, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + plus reduce; + magnitude_squared transform; + + return TensorTransformReduce( + view, identity, reduce, transform, stream, workspace_size); +} + +/// Helper to compute the norm of the elements of a tensor. 
+template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNorm( + TensorView view, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + return std::sqrt(TensorSumSq(view, identity, stream, workspace_size)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to compute the sum of the squares of the differences of two tensors +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorSumSqDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + plus reduce; + magnitude_squared_difference transform; + + return TensorTransformReduce( + view_A, view_B, identity, reduce, transform, stream, workspace_size); +} + + +/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNormDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + return std::sqrt(TensorSumSqDiff(view_A, view_B, identity, stream, workspace_size)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/util/include/cutlass/util/reference/device/thread/gemm.h b/tools/util/include/cutlass/util/reference/device/thread/gemm.h index 11485a91de..318e6c8368 100644 --- a/tools/util/include/cutlass/util/reference/device/thread/gemm.h +++ b/tools/util/include/cutlass/util/reference/device/thread/gemm.h @@ -29,7 +29,6 @@ #pragma once #include "cutlass/coord.h" -#include "cutlass/matrix_traits.h" #include "cutlass/tensor_view.h" #include "cutlass/gemm/gemm.h" diff --git a/tools/util/include/cutlass/util/reference/host/gemm.h b/tools/util/include/cutlass/util/reference/host/gemm.h index 3e38886dd8..98db6dcd95 100644 --- a/tools/util/include/cutlass/util/reference/host/gemm.h +++ b/tools/util/include/cutlass/util/reference/host/gemm.h @@ -33,7 +33,6 @@ #include "cutlass/functional.h" #include "cutlass/numeric_conversion.h" -#include "cutlass/matrix_traits.h" #include "cutlass/tensor_view.h" #include "cutlass/gemm/gemm.h" #include "cutlass/arch/mma.h" diff --git a/tools/util/include/cutlass/util/reference/host/gemm_complex.h b/tools/util/include/cutlass/util/reference/host/gemm_complex.h index 27f368200d..473115ff87 100644 --- a/tools/util/include/cutlass/util/reference/host/gemm_complex.h +++ b/tools/util/include/cutlass/util/reference/host/gemm_complex.h @@ -34,7 +34,6 @@ #include "cutlass/functional.h" #include "cutlass/numeric_conversion.h" -#include "cutlass/matrix_traits.h" #include "cutlass/tensor_view.h" #include "cutlass/gemm/gemm.h" @@ -73,7 +72,12 @@ void GemmComplex( ScalarType beta, TensorRef tensor_c, TensorRef tensor_d, - ComputeType initial_accum) { + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { static_assert( LayoutA::kRank == 2 && @@ -92,61 +96,72 @@ void GemmComplex( ConvertOp convert_op; InnerProductOp 
inner_product_op; - for (int row_block = 0; row_block < M; row_block += Mblock) { - for (int col_block = 0; col_block < N; col_block += Nblock) { + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { - ComputeType accum[Mblock][Nblock]; + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - accum[i][j] = initial_accum; - } - } + ComputeType accum[Mblock][Nblock]; - for (int k_block = 0; k_block < K; ++k_block) { for (int j = 0; j < Nblock; j++) { for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; + accum[i][j] = initial_accum; + } + } - if (row < M && col < N) { - ElementA a = tensor_a.at(MatrixCoord(row, k_block)); - ElementB b = tensor_b.at(MatrixCoord(k_block, col)); + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; - ComputeType a_ik = ComputeType(a); - ComputeType b_kj = ComputeType(b); + if (row < M && col < N) { + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b = tensor_b.at(MatrixCoord(k_block, col)); - if (transform_a == ComplexTransform::kConjugate) { - a_ik = conj(a_ik); - } + ComputeType a_ik = ComputeType(a); + ComputeType b_kj = ComputeType(b); - if (transform_b == ComplexTransform::kConjugate) { - b_kj = conj(b_kj); - } + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } - accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + if (transform_b == ComplexTransform::kConjugate) { + b_kj = conj(b_kj); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } } } } - } - for (int j = 0; j < Nblock; j++) { - for (int i = 0; i < Mblock; i++) { - int row = row_block + i; - int col = col_block + j; + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; - MatrixCoord coord = MatrixCoord(row, col); + MatrixCoord coord = MatrixCoord(row, col); - if (row < M && col < N) { + if (row < M && col < N) { - tensor_d.at(coord) = convert_op( - alpha * ScalarType(accum[i][j]) + - beta * ScalarType(tensor_c.at(coord))); + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } } } - } - } - } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_b.add_pointer_offset(batch_stride_B); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h b/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h index 2a23fd2720..127c501bd3 100644 --- a/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h +++ b/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h @@ -35,7 +35,6 @@ #include "cutlass/numeric_conversion.h" #include "cutlass/tensor_ref_planar_complex.h" -#include "cutlass/matrix_traits.h" #include "cutlass/tensor_view.h" #include "cutlass/gemm/gemm.h" diff --git a/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/tools/util/include/cutlass/util/reference/host/tensor_fill.h index 87c14d61c6..1a0230b55d 
100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_fill.h @@ -38,6 +38,7 @@ #include "cutlass/complex.h" #include "cutlass/array.h" #include "cutlass/numeric_types.h" +#include "cutlass/subbyte_reference.h" #include "cutlass/tensor_view.h" #include "cutlass/tensor_view_planar_complex.h" @@ -300,7 +301,6 @@ void TensorFillRandomGaussian( } /////////////////////////////////////////////////////////////////////////////////////////////////// - /// Fills a tensor with random values with a Gaussian distribution. template < typename Element ///< Element type @@ -319,11 +319,12 @@ void BlockFillRandomGaussian( detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); for (size_t i = 0; i < capacity; ++i) { - ptr[i] = random_func(); + ReferenceFactory::get(ptr, i) = random_func(); } } /////////////////////////////////////////////////////////////////////////////////////////////////// + /////////////////////////////////////////////////////////////////////////////////////////////////// namespace detail { @@ -510,7 +511,6 @@ void TensorFillRandomUniform( } /////////////////////////////////////////////////////////////////////////////////////////////////// - /// Fills a tensor with random values with a uniform random distribution. template < typename Element ///< Element type @@ -527,11 +527,10 @@ void BlockFillRandomUniform( detail::RandomUniformFunc random_func(seed, max, min, bits); for (size_t i = 0; i < capacity; ++i) { - ptr[i] = random_func(); + ReferenceFactory::get(ptr, i) = random_func(); } } -/////////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////// namespace detail { @@ -879,6 +878,135 @@ void BlockFillRandom( /////////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////// +namespace detail { + +template +struct RandomSparseMetaFunc { + + uint64_t seed; + double range; + int MetaSizeInBits; + + // + // Methods + // + + RandomSparseMetaFunc( + uint64_t seed_ = 0, + int MetaSizeInBits_ = 2 + ): + seed(seed_), MetaSizeInBits(MetaSizeInBits_) { + std::srand((unsigned)seed); + if (MetaSizeInBits_ == 2) { + range = 6; + } else if (MetaSizeInBits_ == 4) { + range = 2; + } + } + + /// Compute random value and update RNG state + Element operator()() const { + Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}; + Element TwoToOneMeta[2] = {0x4, 0xe}; + + Element * MetaArray = (MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta; + + Element result = 0x0; + + for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + rnd = range * rnd; + Element meta = MetaArray[(int)rnd]; + + result = (Element)(result | ((Element)(meta << (i * 4)))); + } + + return result; + } +}; + +/// Computes a random sparse meta +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomSparseMetaFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomSparseMetaFunc func; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. 
+ TensorFillRandomSparseMetaFunc( + TensorView view_ = TensorView(), + RandomSparseMetaFunc func_ = RandomSparseMetaFunc() + ): + view(view_), func(func_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) const { + + view.at(coord) = func(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomSparseMeta( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + int MetaSizeInBits) { ///< 2 bit or 4 bit + + detail::RandomSparseMetaFunc random_func(seed, MetaSizeInBits); + + detail::TensorFillRandomSparseMetaFunc func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomSparseMeta( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + int MetaSizeInBits) { ///< 2 bit or 4bit + + detail::RandomSparseMetaFunc random_func(seed, MetaSizeInBits); + + for (size_t i = 0; i < capacity; ++i) { + ptr[i] = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + /// Copies a diagonal in from host memory without modifying off-diagonal elements. template < typename Element, ///< Element type @@ -891,7 +1019,7 @@ void TensorCopyDiagonalIn( for (typename Layout::Index i = 0; i < extent; ++i) { Coord coord(i); - dst.at(coord) = ptr[i]; + dst.at(coord) = ReferenceFactory::get(ptr, i); } } @@ -910,7 +1038,7 @@ void TensorCopyDiagonalOut( for (typename Layout::Index i = 0; i < extent; ++i) { Coord coord(i); - ptr[i] = src.at(coord); + ReferenceFactory::get(ptr, i) = src.at(coord); } } diff --git a/tools/util/include/cutlass/util/reference/host/tensor_norm.h b/tools/util/include/cutlass/util/reference/host/tensor_norm.h index 1d494b9f45..c2958e32e3 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_norm.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_norm.h @@ -24,53 +24,13 @@ **************************************************************************************************/ #pragma once -#include #include "cutlass/cutlass.h" -#include "cutlass/complex.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/util/reference/host/tensor_foreach.h" +// The contents of this file have been moved to 'tensor_reduce' to cover other types of reductions. -namespace cutlass { -namespace reference { -namespace host { +#include "cutlass/util/reference/host/tensor_reduce.h" /////////////////////////////////////////////////////////////////////////////////////////////////// -/// Computes the p=2 norm of the elements of a tensor with arbitrary reduction data type. 
-template < - typename Element, - typename Layout, - typename ElementReduction -> - ElementReduction TensorNorm( - TensorView view, - ElementReduction accumulator) { - TensorForEachLambda( - view.extent(), - [&](typename Layout::TensorCoord const & coord) { - Element element = Element(view.at(coord)); - accumulator = cutlass::norm_accumulate(element, accumulator); - }); - return std::sqrt(accumulator); -} - -/// Computes the p=2 norm of the elements of a tensor. -template < - typename Element, - typename Layout -> -double TensorNorm(TensorView view) { - - return TensorNorm(view, 0); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace host -} // namespace reference -} // namespace cutlass - -/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/include/cutlass/util/reference/host/tensor_reduce.h b/tools/util/include/cutlass/util/reference/host/tensor_reduce.h new file mode 100644 index 0000000000..dd1d4fda66 --- /dev/null +++ b/tools/util/include/cutlass/util/reference/host/tensor_reduce.h @@ -0,0 +1,197 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/util/reference/detail/linear_to_coordinate.h" +#include "cutlass/core_io.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Transform-reduce operation over the elements of a tensor. 
This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view, + ComputeType identity, + ReduceOp reduce, + TransformOp transform +) { + + for (int64_t idx = 0; idx < view.size(); ++idx) { + typename Layout::TensorCoord coord; + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view.extent()); + + if (view.contains(coord)) { + Element x = view.at(coord); + identity = reduce(identity, transform(x)); + } + } + + return identity; +} + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view_A, + TensorView view_B, + ComputeType identity, + ReduceOp reduce, + TransformOp transform) { + + if (view_A.extent() != view_B.extent()) { + throw std::runtime_error("Tensor extents must match."); + } + + for (int64_t idx = 0; idx < view_A.size(); ++idx) { + + typename Layout::TensorCoord coord; + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view_A.extent()); + + if (view_A.contains(coord)) { + Element a = view_A.at(coord); + Element b = view_B.at(coord); + identity = reduce(identity, transform(a, b)); + } + } + + return identity; +} + +/// Helper to compute the sum of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSum( + TensorView view, + ComputeType identity = ComputeType() +) { + + plus reduce; + NumericConverter transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the sum of the squares of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSumSq( + TensorView view, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the norm of the elements of a tensor. 
+template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNorm( + TensorView view, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSq(view, identity)); +} + +/// Helper to compute the sum of the squares of the differences of two tensors +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorSumSqDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared_difference transform; + + return TensorTransformReduce( + view_A, view_B, identity, reduce, transform); +} + + +/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNormDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSqDiff(view_A, view_B, identity)); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////////
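// Illustrative sketch (editorial, not part of the patch): one way the host-side reduction
// helpers defined above might be composed into a verification step. TensorNormDiff() and
// TensorNorm() accumulate in double precision here; the helper name 'relative_norm_check'
// and the default tolerance are assumptions for illustration. The device-side counterparts
// in device/tensor_reduce.h play the same role, with an optional workspace and CUDA stream.

#include "cutlass/tensor_view.h"
#include "cutlass/util/reference/host/tensor_reduce.h"

template <typename Element, typename Layout>
bool relative_norm_check(
  cutlass::TensorView<Element, Layout> test,       // tensor produced by the kernel under test
  cutlass::TensorView<Element, Layout> reference,  // reference result
  double tolerance = 1.0e-5) {

  // || test - reference || and || reference ||, both reduced in double precision
  double diff_norm = cutlass::reference::host::TensorNormDiff(test, reference, double());
  double ref_norm  = cutlass::reference::host::TensorNorm(reference, double());

  // Guard against an all-zero reference tensor.
  return (ref_norm > 0.0) ? (diff_norm / ref_norm <= tolerance) : (diff_norm <= tolerance);
}

// Example use with HostTensor-backed views:
//   bool passed = relative_norm_check(tensor_D.host_view(), reference_D.host_view());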