CUTLASS 2.7 (NVIDIA#318)
CUTLASS 2.7

Mainloop fusion for GEMM: summation over A or B
Strided DGRAD (optimized iterators)
Half-precision GELU_taylor activation functions
Use these when accumulation and epilogue compute types are all cutlass::half_t
Tuning and bug fixes to fused GEMM + GEMM example
Support for smaller than 128b aligned Convolutions: see examples
Caching of results to accelerate Convolution unit tests
Can be enabled or disabled by running cmake .. -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=OFF
Corrections and bug fixes reported by the CUTLASS community
Thank you for filing these issues!

authored-by: Haicheng Wu haichengw@nvidia.com, Manish Gupta manigupta@nvidia.com, Dustyn Blasig dblasig@nvidia.com, Andrew Kerr akerr@nvidia.com
Manish Gupta authored Sep 20, 2021
1 parent 9ac2558 commit 2e07c4c
Showing 62 changed files with 5,611 additions and 186 deletions.
12 changes: 11 additions & 1 deletion CHANGELOG.md
@@ -1,6 +1,16 @@
# NVIDIA CUTLASS Changelog

# CUTLASS 2.x
## [2.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.7.0) (2021-09-24)
* Mainloop fusion for GEMM: [summation over A or B](/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu)
* [Strided DGRAD (optimized iterators)](/include/cutlass/conv/kernel/default_conv2d_dgrad.h)
* [Half-precision GELU_taylor activation functions](/include/cutlass/epilogue/thread/activation.h#L196)
* Use these when accumulation and epilogue compute types are all `cutlass::half_t` (a usage sketch follows this list)
* Tuning and bug fixes to [fused GEMM + GEMM example](/examples/13_two_tensor_op_fusion/)
* Support for smaller than 128b aligned Convolutions: [see examples](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu#L272)
* Caching of results to accelerate Convolution [unit tests](test/unit/conv/device/cache_testbed_output.h)
* Can be enabled or disabled by running `cmake .. -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=OFF`
* Corrections and bug fixes reported by the CUTLASS community
* Thank you for filing these issues!
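
Purely illustrative (not part of the CHANGELOG diff): a minimal sketch of calling the new half-precision `GELU_taylor` functor element-wise, assuming it exposes the same plain `operator()` interface as the other activations in `activation.h`; the actual epilogue wiring is defined by the linked header.

```cpp
#include <iostream>

#include "cutlass/numeric_types.h"
#include "cutlass/epilogue/thread/activation.h"

int main() {
  // Half-precision Taylor-series GELU approximation added in CUTLASS 2.7.
  cutlass::epilogue::thread::GELU_taylor<cutlass::half_t> gelu;

  cutlass::half_t x(0.5f);
  cutlass::half_t y = gelu(x);

  // GELU(0.5) is roughly 0.345; FP16 rounding may shift the last digits.
  std::cout << float(y) << std::endl;
  return 0;
}
```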

## [2.6.1](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.1) (2021-09-03)
* Arbitrary padding and striding for CUTLASS Strided DGRAD Convolution operator (Analytic Iterators)
36 changes: 33 additions & 3 deletions CMakeLists.txt
@@ -32,7 +32,7 @@ endif()

message(STATUS "CMake Version: ${CMAKE_VERSION}")

project(CUTLASS VERSION 2.6.0 LANGUAGES CXX)
project(CUTLASS VERSION 2.7.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)

if (CUDA_VERSION VERSION_LESS 10.2)
@@ -188,10 +188,18 @@ set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma delimited list of kern

# Test Levels L0, L1, L2
set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")


set(CUTLASS_TEST_ENABLE_CACHED_RESULTS ON CACHE BOOL "Enable caching and reuse of test results in unit tests")

set_property(CACHE CUTLASS_TEST_LEVEL PROPERTY STRINGS 0 1 2)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})

if (CUTLASS_TEST_ENABLE_CACHED_RESULTS)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1)
endif()
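
For illustration only (an assumption about how the flag is consumed, not the actual testbed code): when `CUTLASS_TEST_ENABLE_CACHED_RESULTS` is ON, the compile definition added above lets unit tests skip recomputing reference results. The cache and helpers below are hypothetical stand-ins; the real caching logic lives in test/unit/conv/device/cache_testbed_output.h.

```cpp
#include <iostream>
#include <map>
#include <string>

// Hypothetical stand-in for the real caching utilities.
static std::map<std::string, std::string> result_cache;

std::string compute_reference_digest() {
  // Placeholder for an expensive host-side reference computation.
  return "reference-digest";
}

std::string get_reference(std::string const &problem_key) {
#if defined(CUTLASS_TEST_ENABLE_CACHED_RESULTS)
  // Reuse a previously cached reference result when one exists.
  auto it = result_cache.find(problem_key);
  if (it != result_cache.end()) {
    return it->second;
  }
  std::string digest = compute_reference_digest();
  result_cache[problem_key] = digest;
  return digest;
#else
  // Always recompute the reference when caching is disabled.
  return compute_reference_digest();
#endif
}

int main() {
  std::cout << get_reference("conv2d_fprop_example_problem") << std::endl;
  return 0;
}
```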

#
# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.
#
@@ -244,7 +252,7 @@ if (NOT MSVC AND CUTLASS_NVCC_KEEP)
# MSVC flow handles caching already, but for other generators we handle it here.
set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files")
file(MAKE_DIRECTORY ${CUTLASS_NVCC_KEEP_DIR})
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep) # --keep-dir may not work with nvcc for some directories.
list(APPEND CUTLASS_CUDA_NVCC_FLAGS --keep -v) # --keep-dir may not work with nvcc for some directories.
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -save-temps=${CUTLASS_NVCC_KEEP_DIR})
endif()

@@ -572,17 +580,30 @@ function(cutlass_add_executable_tests NAME TARGET)
# TEST_COMMAND_OPTIONS: A list of variables (i.e. by reference params) which contain command line arguments
# to pass to the test executable. A unique test with suffix _0, _1, ... is generated for each set of
# options given. If this option is not used, a single test with no arguments is generated.
# RESULT_CACHE_FILE: A file to be installed alongside the test executable with pre-computed
# test results to speed up test runtime.
#

set(options DISABLE_EXECUTABLE_INSTALL_RULE)
set(oneValueArgs DISABLE_TESTS)
set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE)
set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

if (NOT DEFINED __DISABLE_TESTS)
set(__DISABLE_TESTS OFF)
endif()

if (__RESULT_CACHE_FILE)

add_custom_command(
TARGET ${TARGET}
POST_BUILD
COMMAND ${CMAKE_COMMAND}
ARGS -E copy ${__RESULT_CACHE_FILE} "$<TARGET_FILE_DIR:${TARGET}>"
)

endif()

if (NOT __DISABLE_EXECUTABLE_INSTALL_RULE AND CUTLASS_INSTALL_TESTS)

# file(RELATIVE_PATH CMAKE_CURRENT_BINARY_RELATIVE_DIR ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR})
@@ -591,6 +612,15 @@ function(cutlass_add_executable_tests NAME TARGET)
TARGETS ${TARGET}
RUNTIME DESTINATION ${CUTLASS_TEST_INSTALL_BINDIR}
)

if (__RESULT_CACHE_FILE)

install(
FILES ${__RESULT_CACHE_FILE}
DESTINATION ${CUTLASS_TEST_INSTALL_BINDIR}/
)

endif()

endif()

34 changes: 22 additions & 12 deletions README.md
@@ -1,15 +1,15 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")

# CUTLASS 2.6
# CUTLASS 2.7

_CUTLASS 2.6.1 - September 2021_
_CUTLASS 2.7 - September 2021_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
It incorporates strategies for hierarchical decomposition and data movement similar
to those used to implement cuBLAS. CUTLASS decomposes these "moving parts" into
reusable, modular software components abstracted by C++ template classes. These
thread-wide, warp-wide, block-wide, and device-wide primitives can be specialized
high-performance matrix-multiplication (GEMM) and related computations at all levels
and scales within CUDA. It incorporates strategies for hierarchical decomposition and
data movement similar to those used to implement cuBLAS and cuDNN. CUTLASS decomposes
these "moving parts" into reusable, modular software components abstracted by C++ template
classes. These thread-wide, warp-wide, block-wide, and device-wide primitives can be specialized
and tuned via custom tiling sizes, data types, and other algorithmic policy. The
resulting flexibility simplifies their use as building blocks within custom kernels
and applications.
@@ -20,14 +20,14 @@ multiply-accumulate abstractions for half-precision floating
point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32),
single-precision floating point (FP32), double-precision floating
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).

Furthermore, CUTLASS demonstrates warp-synchronous matrix multiply operations
CUTLASS demonstrates warp-synchronous matrix multiply operations
targeting the programmable, high-throughput _Tensor Cores_ implemented by
NVIDIA's Volta, Turing, and Ampere architectures.

Additionally, CUTLASS implements high-performance convolution (implicit GEMM).
Implicit GEMM is the formulation of a convolution operation as a GEMM. This allows CUTLASS
to build convolutions by reusing highly optimized warp-wide GEMM components and below.
CUTLASS implements high-performance Convolution via the implicit GEMM algorithm.
Implicit GEMM is the formulation of a convolution operation as a GEMM, thereby taking advantage of
CUTLASS's modular GEMM pipeline.
This allows CUTLASS to build convolutions by reusing highly optimized warp-wide GEMM components and below.
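
A small illustration of the implicit-GEMM formulation described above, using plain arithmetic rather than any CUTLASS API: for forward propagation with NHWC activations and a KRSC filter, the convolution maps onto a GEMM with the dimensions computed below (illustrative sizes; unit stride shown, no dilation).

```cpp
#include <cstdio>

int main() {
  // Convolution problem (NHWC activation, KRSC filter), illustrative sizes.
  int N = 8, H = 56, W = 56, C = 64;   // input:  N x H x W x C
  int K = 128, R = 3, S = 3;           // filter: K x R x S x C
  int pad = 1, stride = 1;

  int P = (H + 2 * pad - R) / stride + 1;   // output height
  int Q = (W + 2 * pad - S) / stride + 1;   // output width

  // Fprop viewed as an implicit GEMM:
  int gemm_m = N * P * Q;   // one row per output pixel
  int gemm_n = K;           // one column per output channel
  int gemm_k = R * S * C;   // reduction over the filter footprint

  std::printf("GEMM M=%d N=%d K=%d\n", gemm_m, gemm_n, gemm_k);
  return 0;
}
```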

See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.

@@ -36,6 +36,16 @@ supported at each level of the execution model hierarchy.

See the [CHANGELOG](CHANGELOG.md) for descriptions of recent updates.

# What's New in CUTLASS 2.7
CUTLASS 2.7 is a minor update to CUTLASS adding:
- Mainloop fusion for GEMM: [summation over A or B](/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu)
- [Optimizations for strided DGRAD](/include/cutlass/conv/kernel/default_conv2d_dgrad.h)
- [Half-precision GELU_taylor activation functions](/include/cutlass/epilogue/thread/activation.h#L196)
- Tuning and bug fixes to [fused GEMM + GEMM example](/examples/13_two_tensor_op_fusion/)
- Support for smaller than 128b aligned Convolutions: [see examples](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu#L272)
- Caching of results to accelerate Convolution [unit tests](test/unit/conv/device/cache_testbed_output.h)
- Numerous updates from the community (thanks!)

# What's New in CUTLASS 2.6
CUTLASS 2.6 is a minor update to CUTLASS adding:
- Fused [broadcast](test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu) and [reductions](/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu) in the epilogues of GEMM and Convolution
2 changes: 1 addition & 1 deletion examples/03_visualize_layout/CMakeLists.txt
@@ -21,7 +21,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

set(TEST_COMMAND_00 RowMajor --extent=16,16)
set(TEST_COMMAND_01 "ColumnMajorInterleaved<4>" --extent=32,8 --output-shape=16 --vectorize=4)
set(TEST_COMMAND_01 \"ColumnMajorInterleaved<4>\" --extent=32,8 --output-shape=16 --vectorize=4)

cutlass_example_add_executable(
03_visualize_layout
28 changes: 28 additions & 0 deletions include/cutlass/arch/memory.h
@@ -225,6 +225,34 @@ struct global_store;
//
/////////////////////////////////////////////////////////////////////////////////////////////////


template <typename AccessType>
struct global_store<AccessType, 64> {
CUTLASS_DEVICE
global_store(AccessType const &D, void *ptr, bool pred_guard) {
uint4 const *data = reinterpret_cast<uint4 const *>(&D);

asm volatile(
"{\n"
" .reg .pred p;\n"
" setp.ne.b32 p, %5, 0;\n"
" @p st.global.v4.u32 [%0], {%1, %2, %3, %4};\n"
" @p st.global.v4.u32 [%6], {%7, %8, %9, %10};\n"
" @p st.global.v4.u32 [%11], {%12, %13, %14, %15};\n"
" @p st.global.v4.u32 [%16], {%17, %18, %19, %20};\n"
"}\n"
:
: "l"(ptr), "r"(data[0].x), "r"(data[0].y), "r"(data[0].z),
"r"(data[0].w), "r"((int)pred_guard), "l"(((uint8_t *)ptr) + 16),
"r"(data[1].x), "r"(data[1].y), "r"(data[1].z), "r"(data[1].w),
"l"(((uint8_t *)ptr) + 32),
"r"(data[2].x), "r"(data[2].y), "r"(data[2].z), "r"(data[2].w),
"l"(((uint8_t *)ptr) + 48),
"r"(data[3].x), "r"(data[3].y), "r"(data[3].z), "r"(data[2].w));
}
};
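
A hedged usage sketch for the 64-byte specialization above: a predicated 64B store issued as four 16B vector stores. `Packet` is an illustrative stand-in for any 16B-aligned, 64-byte access type; it is not a CUTLASS type.

```cpp
#include "cutlass/arch/memory.h"

// Illustrative 64-byte, 16B-aligned payload (not a CUTLASS type).
struct alignas(16) Packet {
  float data[16];   // 16 x 4B = 64 bytes
};

__global__ void copy_packets(Packet const *src, Packet *dst, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  bool guard = (idx < n);   // predicate: out-of-range threads store nothing

  Packet value = guard ? src[idx] : Packet{};

  // Constructor-style invocation matching the specialization's signature;
  // the address is only written when the guard predicate is true.
  cutlass::arch::global_store<Packet, 64>(value, dst + idx, guard);
}
```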


template <typename AccessType>
struct global_store<AccessType, 32> {
CUTLASS_DEVICE
