CUTLASS 3.4.0 (NVIDIA#1286)
* CUTLASS 3.4.0

* Update CHANGELOG.md

---------

Co-authored-by: Pradeep Ramani <prramani@nvidia.com>
IonThruster authored Dec 29, 2023
1 parent b7508e3 commit 8236f30
Showing 211 changed files with 11,336 additions and 2,690 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -1,4 +1,12 @@
# NVIDIA CUTLASS Changelog
## [3.4](https://github.com/NVIDIA/cutlass/releases/tag/v3.4) (2023-12-29)
* Expanded [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors.
* Performance improvements to [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm).
* Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above).
* Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above).
* NamedBarriers usability improvements, and an officially released list of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h).
* Improved [CuTe TMA Tensor](/media/docs/cute/0z_tma_tensors.md) documentation.


## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3) (2023-10-31)
* [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.
33 changes: 25 additions & 8 deletions CMakeLists.txt
@@ -40,7 +40,7 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")
set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set")

project(CUTLASS VERSION 3.3.0 LANGUAGES CXX)
project(CUTLASS VERSION 3.4.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)

if (CUDA_VERSION VERSION_LESS 11.3)
@@ -681,6 +681,12 @@ endif()

################################################################################

set(CUTLASS_DEFAULT_ACTIVE_TEST_SETS "default" CACHE STRING "Default
activated test sets. In `make test` mode, this string determines the
active set of tests. In `ctest` mode, this value can be overridden
with CUTLASS_TEST_SETS environment variable when running the ctest
executable.")
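As a brief sketch of how this cache variable and the runtime override fit together (the second set name below is illustrative, not one defined by this commit):

```cmake
# Hypothetical initial-cache file, passed as `cmake -C test_sets.cmake ..`,
# that widens the default active test sets at configure time.
# "performance" is an illustrative set name, not defined by this commit.
set(CUTLASS_DEFAULT_ACTIVE_TEST_SETS "default;performance"
    CACHE STRING "Default activated test sets")

# At test time, the generated CTest files consult the CUTLASS_TEST_SETS
# environment variable instead, e.g. `CUTLASS_TEST_SETS=performance ctest`.
```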

set(CUTLASS_CTEST_TEMPLATE_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/CTestTestfile.configure.cmake)
set(CUTLASS_CTEST_GENERATED_FILES "" CACHE INTERNAL "")

@@ -701,11 +707,12 @@ function(cutlass_add_executable_tests NAME TARGET)
# generating the full variable name to be referenced.
# RESULT_CACHE_FILE: A file to be installed alongside the test executable with pre-computed
# test results to speed up test runtime.
# TEST_SETS_SUPPORTED: A list of test set names these tests support.
#

set(options DISABLE_EXECUTABLE_INSTALL_RULE)
set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE TEST_COMMAND_OPTIONS_PREFIX)
set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS)
set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS TEST_SETS_SUPPORTED)
cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

if (NOT DEFINED __DISABLE_TESTS)
@@ -715,6 +722,12 @@ function(cutlass_add_executable_tests NAME TARGET)
set(TEST_EXE $<TARGET_FILE_NAME:${TARGET}>)
set(TEST_EXE_WORKING_DIRECTORY ./${CMAKE_INSTALL_BINDIR})

if (NOT DEFINED __TEST_SETS_SUPPORTED)
set(__TEST_SETS_SUPPORTED ${CUTLASS_DEFAULT_ACTIVE_TEST_SETS})
endif()

set(TEST_SETS_SUPPORTED ${__TEST_SETS_SUPPORTED})

if (__RESULT_CACHE_FILE)

add_custom_command(
@@ -816,8 +829,6 @@ function(cutlass_add_executable_tests NAME TARGET)
set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/ctest/${TEST_NAME})
file(MAKE_DIRECTORY ${TEST_GEN_DIR})

set(TEST_SETS_SUPPORTED default)

set(TEST_EXE_PATH $<TARGET_FILE:${TARGET}>)
set(TEST_USE_EXTENDED_FORMAT ON)
configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" @ONLY)
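For illustration, a hypothetical call that restricts an executable's tests to a non-default set might look like the sketch below (the target and set names are assumptions, not taken from this commit):

```cmake
# Hypothetical usage of the new TEST_SETS_SUPPORTED argument: these tests are
# generated only when the requested test sets include "ci_perf"; callers that
# omit the argument fall back to CUTLASS_DEFAULT_ACTIVE_TEST_SETS.
cutlass_add_executable_tests(
  test_unit_gemm_device_perf            # NAME of the test group
  cutlass_test_unit_gemm_device_perf    # TARGET (assumed to exist)
  TEST_SETS_SUPPORTED ci_perf
  )
```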
@@ -873,9 +884,9 @@ if (CUTLASS_INSTALL_TESTS)
file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/ctest")

file(WRITE "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "# Generated File\n\n")

file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "if (NOT DEFINED ENV{CUTLASS_TEST_SET})\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" " set(ENV{CUTLASS_TEST_SET} \"default\")\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "cmake_policy(SET CMP0057 NEW) # Allow IN_LIST for if()\n\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "if (NOT DEFINED ENV{CUTLASS_TEST_SETS})\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" " set(ENV{CUTLASS_TEST_SETS} ${CUTLASS_DEFAULT_ACTIVE_TEST_SETS})\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "endif()\n\n")

foreach(GENERATED_FILE ${CUTLASS_CTEST_GENERATED_FILES})
@@ -897,9 +908,15 @@ write_basic_package_version_file(
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake
COMPATIBILITY AnyNewerVersion)

configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassConfig.cmake.in
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake
@ONLY
)

install(
FILES
${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassConfig.cmake
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake
${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/NvidiaCutlass/
)
4 changes: 4 additions & 0 deletions CONTRIBUTORS.md
@@ -13,6 +13,7 @@ Cris Cecka<br />
Aniket Shivam<br />
Jack Kosaian<br />
Mark Hoemmen<br />
Richard Cai<br />
Honghao Lu<br />
Ethan Yan<br />
Haicheng Wu<br />
@@ -21,6 +22,8 @@ Dustyn Blasig<br />
Fengqi Qiao<br />
Duane Merrill<br />
Yujia Zhai<br />
Rawn Henry<br />
Sergey Klevtsov<br />
Shang Zhang<br />
Piotr Majcher<br />
Paul Springer<br />
@@ -55,6 +58,7 @@ Alan Kaatz<br />
Tina Li<br />
Timmy Liu<br />
Wei Liu<br />
Tim Martin<br />
Duane Merrill<br />
Kevin Siu<br />
Markus Tavenrath<br />
6 changes: 5 additions & 1 deletion CUDA.cmake
@@ -248,11 +248,15 @@ function(cutlass_unify_source_files TARGET_ARGS_VAR)
message(FATAL_ERROR "TARGET_ARGS_VAR parameter is required")
endif()

if (NOT DEFINED __BATCH_SOURCES)
set(__BATCH_SOURCES ON)
endif()

if (__BATCH_SOURCES AND NOT DEFINED __BATCH_SIZE)
set(__BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE})
endif()

if (CUTLASS_UNITY_BUILD_ENABLED AND DEFINED __BATCH_SIZE AND __BATCH_SIZE GREATER 1)
if (CUTLASS_UNITY_BUILD_ENABLED AND __BATCH_SOURCES AND __BATCH_SIZE GREATER 1)

set(CUDA_FILE_ARGS)
set(TARGET_SOURCE_ARGS)
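A caller-side sketch of the new BATCH_SOURCES default (the call form and file names are assumptions; only the option names come from the function above):

```cmake
# With this change BATCH_SOURCES defaults to ON, so unity-build batching is
# applied whenever CUTLASS_UNITY_BUILD_ENABLED is set and BATCH_SIZE > 1,
# unless the caller opts out explicitly. File names below are illustrative.
cutlass_unify_source_files(TARGET_SOURCE_ARGS kernel_a.cu kernel_b.cu)                   # batched by default
cutlass_unify_source_files(TARGET_SOURCE_ARGS BATCH_SOURCES OFF kernel_a.cu kernel_b.cu) # never batched
```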
22 changes: 10 additions & 12 deletions README.md
@@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")

# CUTLASS 3.3
# CUTLASS 3.4

_CUTLASS 3.3 - October 2023_
_CUTLASS 3.4 - December 2023_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
@@ -41,17 +41,14 @@ and improves code composability and readability. More documentation specific to

In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.
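For reference, in the forward-propagation case on NHWC tensors (N×H×W×C activations, K×R×S×C filters, N×P×Q×K output), the standard implicit-GEMM mapping described in the CUTLASS documentation (media/docs/implicit_gemm_convolution.md) yields a GEMM whose M, N, and K extents are N·P·Q, K, and C·R·S respectively, so the convolution is computed as a single matrix product without materializing the im2col matrix.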

# What's New in CUTLASS 3.3
# What's New in CUTLASS 3.4

CUTLASS 3.3.0 is an update to CUTLASS adding:
CUTLASS 3.4.0 is an update to CUTLASS adding:

- New [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input types with optimal performance.
- New [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operandB {fp16, bf16} x {s8, u8} and upcast on operandA {s8, u8} x {fp16, bf16}. They also include fast numeric conversion recipes and warp level shuffles to achieve optimal performance.
- New [Copy Async based Hopper GEMMs](/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors (across s8/fp8/fp16/bf16/tf32 types) with optimal performance. As a part of this, new kernel schedules, and Copy Ops [SM80\_CP\_ASYNC\_CACHE\_\*](/include/cute/arch/copy_sm80.hpp) were also added.
- EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details.
- Various subbyte enhancements like tagged device ptrs, support for vectorized copy, various operators to treat subbyte iterators as pointers, and full-fledged CuTe Tensor support.
- Support for Clang as a host compiler.
- Support for void-C kernels and SM80 mixed-input GEMMs in the CUTLASS Python interface
- Improved [Mixed-input Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) supporting {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors tuned for optimal performance on Hopper H100.
- Beta release of [Pointer-Array Batched GEMMs](/examples/56_hopper_ptr_array_batched_gemm) utilizing TMA and Hopper H100 tensor cores. (Requires CUDA 12.3 or above)
- Beta release of [Group-GEMM](/examples/57_hopper_grouped_gemm), commonly used in optimizing Mixture-of-Experts models, now available on Hopper GPUs taking advantage of TMA and Hopper H100 tensor cores. (Requires CUDA 12.3 or above)
- Improvements to NamedBarriers, including details of [ReservedNamedBarriers](/include/cutlass/arch/barrier.h) used within the CUTLASS library.

Minimum requirements:

@@ -95,7 +92,7 @@ as shown in the above figure. Tensor Core operations are implemented using CUDA

CUTLASS requires a C++17 host compiler and
performs best when built with the [**CUDA 12.2.2 Toolkit**](https://developer.nvidia.com/cuda-toolkit-archive).
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0 and CUDA 12.1.
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2 and CUDA 12.3.1

## Operating Systems
We have tested the following environments.
@@ -107,6 +104,7 @@ We have tested the following environments.
| Ubuntu 22.04 | GCC 11.2.0 |
| Ubuntu 22.04 | Clang 10.0.0 |
| Ubuntu 22.04 | Clang 14.0.6 |
| Ubuntu 22.04 | Clang 17.0.6 |
| Windows 10.0 | Visual Studio 2019 v16.11.27 |

Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended.
14 changes: 10 additions & 4 deletions cmake/CTestTestfile.configure.cmake
@@ -2,10 +2,16 @@

set(TEST_SETS_SUPPORTED @TEST_SETS_SUPPORTED@)

#? if (DEFINED ENV{CUTLASS_TEST_SET} AND NOT ENV{CUTLASS_TEST_SET} IN_LIST TEST_SETS_SUPPORTED)
#? message(STATUS "Skipping tests for @TEST_EXE_PATH@ as $ENV{CUTLASS_TEST_SET} is not in the set of ${TEST_SETS_SUPPORTED}.")
#? return()
#? endif()
if (NOT DEFINED ENV{CUTLASS_TEST_SETS})
set(ENV{CUTLASS_TEST_SETS} @CUTLASS_DEFAULT_ACTIVE_TEST_SETS@)
endif()

foreach(TEST_SET_REQUESTED IN ITEMS $ENV{CUTLASS_TEST_SETS})
if (NOT TEST_SET_REQUESTED IN_LIST TEST_SETS_SUPPORTED)
message(STATUS "Skipping tests for @TEST_EXE_PATH@ as ${TEST_SET_REQUESTED} is not in the set of [${TEST_SETS_SUPPORTED}].")
return()
endif()
endforeach()

set(TEST_EXE_PATH @TEST_EXE_PATH@)
set(TEST_EXE_WORKING_DIRECTORY @TEST_EXE_WORKING_DIRECTORY@)
cmake/NvidiaCutlassConfig.cmake.in
@@ -7,6 +7,3 @@ if(TARGET nvidia::cutlass::CUTLASS)
endif()

include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake")

# For backward compatibility with the old name
add_library(cutlass_lib ALIAS cutlass_library)
4 changes: 2 additions & 2 deletions examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu
@@ -291,8 +291,8 @@ int run() {
LayoutInputB,
ElementOutput,
LayoutOutput,
int32_t,
int32_t>
ElementComputeEpilogue,
ElementComputeEpilogue>
gemm_device;

// Launch device reference gemm kernel
@@ -279,7 +279,7 @@ struct Options {
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {

out << "28_ampere_gemm_bias_fusion example\n\n"
out << "23_ampere_operand_gemm_reduction_fusion\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
<< " --m=<int> GEMM M\n"
@@ -297,7 +297,7 @@ struct Options {
<< " --tag=<string> String to replicate across the first column in the results table\n";

out << "\n\nExamples:\n\n"
<< "$ ./examples/23_ampere_gemm_bias_fusion_example/ampere_gemm_bias_fusion --m=1024 --n=1024 --k=1024 \n\n";
<< "$ ./examples/23_ampere_gemm_operand_reduction_fusion/23_ampere_gemm_operand_reduction_fusion --m=1024 --n=1024 --k=1024 \n\n";

return out;
}
2 changes: 1 addition & 1 deletion examples/30_wgrad_split_k/30_wgrad_split_k.cu
@@ -602,7 +602,7 @@ Result profile_convolution(Options const &options) {

std::stringstream ss;

ss << "26_ampere_fused_wgrad_batch_normalization_"
ss << "30_wgrad_split_k_"
<< options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c()
<< "_"
<< options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c()
2 changes: 1 addition & 1 deletion examples/34_transposed_conv2d/34_transposed_conv2d.cu
@@ -251,7 +251,7 @@ struct Options {
<< " --tag=<string> String to replicate across the first column in the results table\n";

out << "\n\nExamples:\n\n"
<< "$ ./examples/31_transposed_conv2d/31_transposed_conv2d --n=8 --h=32 --w=32 --c=16 --k=32 --r=3 --s=3\n\n";
<< "$ ./examples/34_transposed_conv2d/34_transposed_conv2d --n=8 --h=32 --w=32 --c=16 --k=32 --r=3 --s=3\n\n";

return out;
}
2 changes: 1 addition & 1 deletion examples/38_syr2k_grouped/syr2k_grouped.cu
@@ -398,7 +398,7 @@ struct Options {
<< "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --benchmark=problems.txt\n\n"

<< "# Execute Grouped SYR2K and profile with NSight\n"
<< "$ nv-nsight-cu-cli ./examples/24_gemm_grouped/24_gemm_grouped --n=256 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n";
<< "$ nv-nsight-cu-cli ./examples/38_syr2k_grouped/38_syr2k_grouped --n=256 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n";

return out;
}
@@ -306,7 +306,7 @@ struct Options {

/// Prints the usage statement.
std::ostream &print_usage(std::ostream &out) const {
out << "41_depthwise_gemm_fprop example\n\n"
out << "46_depthwise_gemm_fprop example\n\n"
<< " This example uses Ampere's Tensor Core operators on F16 data types to compute\n"
<< " forward convolution on tensors of layout NHWC.\n\n"
<< "Options:\n\n"
@@ -554,7 +554,7 @@ Result profile_convolution(Options const &options) {
if (options.save_workspace) {
std::stringstream ss;

ss << "45_depthwise_simt_conv2dfprop" << options.input_size.n() << "x" << options.input_size.h()
ss << "46_depthwise_simt_conv2dfprop" << options.input_size.n() << "x" << options.input_size.h()
<< "x" << options.input_size.w() << "x" << options.input_size.c() << "_"
<< options.filter_size.n() << "x" << options.filter_size.h() << "x"
<< options.filter_size.w() << "x" << options.filter_size.c() << ".dat";
@@ -291,7 +291,7 @@ struct ExampleRunner {
using CustomEVT = // alpha * acc + beta * C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add, ElementD, ElementCompute, RoundStyle>, // beta * C + (alpha * acc)
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>, // beta
cutlass::epilogue::fusion::Sm90SrcFetch, // C
cutlass::epilogue::fusion::Sm90SrcFetch<ElementC>, // C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>, // alpha
cutlass::epilogue::fusion::Sm90AccFetch // acc
@@ -302,7 +302,7 @@ struct ExampleRunner {
// Users can select one of these operations by passing one of the tags defined in include/cutlass/epilogue/fusion/operations.hpp
// to the CollectiveBuilder. This frees the user from having to compute additional parameters such as stage counts and copy atoms/layouts.
// These tags also provide additional metadata that can be queried at compile time.
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementScalar, RoundStyle>;
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
@@ -568,7 +568,7 @@ struct ExampleRunner
if (options.reference_check) {
if (!verify()) {
std::cout << "Failed validation" << std::endl;
#if 1
#if 0
debug_output(std::cout);
#endif
return false;
2 changes: 1 addition & 1 deletion examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp
@@ -122,7 +122,7 @@ class GemmGather
static_assert(cute::size(GmemTiledCopyA{}) == cute::size(GmemTiledCopyB{}), "Number of threads in A/B tiled copies must be the same.");
static constexpr uint32_t NumLoadWarpGroups = cute::size(GmemTiledCopyA{}) / NumThreadsPerWarpGroup;
static constexpr uint32_t NumMmaWarpGroups = cute::size(TiledMma{}) / NumThreadsPerWarpGroup;
static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(cute::size(TiledMma{})) / NumThreadsPerWarpGroup;
static constexpr uint32_t NumWarpGroups = NumLoadWarpGroups + NumMmaWarpGroups;
static_assert(NumWarpGroups == 2 || NumWarpGroups == 3, "Number of warp groups must be 2 or 3 for good performance.");
7 changes: 5 additions & 2 deletions examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu
@@ -109,6 +109,8 @@
namespace example
{

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

struct Options {

bool help;
@@ -724,6 +726,7 @@ private:
return true;
}
};
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

} // namespace example

@@ -749,7 +752,7 @@ int main(int argc, char const **argv)
if (notSupported) {
return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems
}

#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
example::Options options;
options.parse(argc, argv);

@@ -970,6 +973,6 @@ int main(int argc, char const **argv)
result &= runner.run(options);
}
#endif

return result ? EXIT_SUCCESS : EXIT_FAILURE;
#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
}
@@ -109,9 +109,11 @@ using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;

// Auxiliary matrix configuration
// Auxiliary matrix configuration and other fusion types
using ElementAux = ElementC;
using LayoutAux = LayoutC;
using ElementAmax = float;
using ElementBias = float;

// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
@@ -124,7 +126,7 @@ using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux<
LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux>;
LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementC>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
