Skip to content

Commit

Permalink
CUTLASS 2.2 (NVIDIA#96)
Browse files Browse the repository at this point in the history
Adds support for NVIDIA Ampere Architecture features. CUDA 11 Toolkit recommended.
  • Loading branch information
kerrmudgeon authored Jun 8, 2020
1 parent e33d90b commit 86931fe
Show file tree
Hide file tree
Showing 584 changed files with 51,095 additions and 3,388 deletions.
16 changes: 16 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,22 @@

# CUTLASS 2.x

## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08)
* [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
* Fast Tensor Core operations:
* Maximum performance via [`mma.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends)
* Tensor Float 32, BFloat16, and double-precision data types
* Mixed integer data types (int8, int4, bin1)
* Asynchronous copy for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution)
* Features:
* SDK examples showing GEMM fused with bias+relu and fused GEMM+GEMM
* Complex-valued GEMMs targeting NVIDIA Ampere Tensor Cores in double-precision and Tensor Float 32
* Gaussian complex GEMMs using 3m complex multiply algorithm
* Universal GEMM kernel supporting two batch modes and two algorithms for parallel reductions
* Policy updates:
* [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit) needed to enable NVIDIA Ampere Architecture features
* Disabled F16C by default for compatibility - enable on cmake command line with `-DCUTLASS_ENABLE_F16C=ON`

## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06)
* BLAS-style host-side API added to [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)
* API to launch compiled kernel instances for GEMM and planar complex GEMM
Expand Down
17 changes: 10 additions & 7 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
Expand Down Expand Up @@ -32,7 +32,7 @@ endif()

message(STATUS "CMake Version: ${CMAKE_VERSION}")

project(CUTLASS VERSION 2.1.0 LANGUAGES CXX)
project(CUTLASS VERSION 2.2.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)

find_package(Doxygen QUIET)
Expand Down Expand Up @@ -84,7 +84,7 @@ endif()

set(CUTLASS_NVCC_ARCHS_SUPPORTED "")
if (NOT CUDA_VERSION VERSION_LESS 7.5)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 50)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 53)
endif()
if (NOT CUDA_VERSION VERSION_LESS 8.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 60 61)
Expand All @@ -98,6 +98,9 @@ endif()
if (NOT CUDA_VERSION VERSION_LESS 10.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 75)
endif()
if (NOT CUDA_VERSION VERSION_LESS 11.0)
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 80)
endif()
set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.")
set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.")

Expand Down Expand Up @@ -154,7 +157,7 @@ endif()
set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.")
set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
set(CUTLASS_ENABLE_F16C ON CACHE BOOL "Enable F16C x86 extensions in host code.")
set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.")

#
# CUTLASS generator cmake configuration
Expand Down Expand Up @@ -248,8 +251,8 @@ if(CUDA_COMPILER MATCHES "[Cc]lang")
endif()

list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-path=${CUDA_TOOLKIT_ROOT_DIR})
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm=-pragma-unroll-threshold=100000)
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm=-unroll-threshold=5000)
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -pragma-unroll-threshold=100000)
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -unroll-threshold=5000)
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unused-command-line-argument)

string(REPLACE "." ";" CUDA_VERSION_PARTS ${CMAKE_CUDA_COMPILER_VERSION})
Expand All @@ -271,7 +274,7 @@ function(cutlass_apply_cuda_gencode_flags TARGET)
set(NVCC_FLAGS)
set(CLANG_FLAGS)
foreach(ARCH ${CUTLASS_NVCC_ARCHS_ENABLED})
list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH})
list(APPEND CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH})
set(CODES)
if(CUTLASS_NVCC_EMBED_CUBIN)
list(APPEND CODES sm_${ARCH})
Expand Down
14 changes: 7 additions & 7 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,28 @@ This is the official list of CUTLASS developers and contributors.
## DEVELOPERS
Andrew Kerr
Haicheng Wu
Naila Farooqui
Manish Gupta
Dustyn Blasig
Pradeep Ramani
Manish Gupta
Aditya Atluri
Naila Farooqui
Piotr Majcher
Paul Springer
David Tanner
Scott Yokim
Jin Wang
Scott Yokim
Markus Hohnerbach
Aditya Atluri
David Tanner

## CONTRIBUTORS
Timothy Costa
Julien Demouth
Brian Fahs
Michael Goldfarb
Mostafa Hagog
Markus Hohnerbach
Fei Hu
Alan Kaatz
Tina Li
Timmy Liu
Piotr Majcher
Duane Merrill
Kevin Siu
Markus Tavenrath
Expand Down
6 changes: 3 additions & 3 deletions CUDA.cmake
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification, are permitted
# provided that the following conditions are met:
Expand Down Expand Up @@ -206,14 +206,14 @@ include_directories(SYSTEM ${CUDA_INCLUDE_DIRS})
function(cutlass_correct_source_file_language_property)
if(CUDA_COMPILER MATCHES "clang")
foreach(File ${ARGN})
if(${File} MATCHES ".*\.cu$")
if(File MATCHES ".*\.cu$")
set_source_files_properties(${File} PROPERTIES LANGUAGE CXX)
endif()
endforeach()
endif()
endfunction()

set(CUTLASS_UNITY_BUILD_ENABLED ON CACHE BOOL "Enable combined source compilation")
set(CUTLASS_UNITY_BUILD_ENABLED OFF CACHE BOOL "Enable combined source compilation")
set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files")

function(cutlass_unify_source_files TARGET_ARGS_VAR)
Expand Down
2 changes: 1 addition & 1 deletion LICENSE.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Copyright (c) 2017 - 2019, NVIDIA CORPORATION. All rights reserved.
Copyright (c) 2017 - 2020, NVIDIA CORPORATION. All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
Expand Down
100 changes: 59 additions & 41 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")

# CUTLASS 2.1
# CUTLASS 2.2

_CUTLASS 2.1 - April 2020_
_CUTLASS 2.2 - June 2020_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
Expand All @@ -17,22 +17,35 @@ and applications.
To support a wide variety of applications, CUTLASS provides extensive support for
mixed-precision computations, providing specialized data-movement and
multiply-accumulate abstractions for half-precision floating
point (FP16), single-precision floating point (FP32), double-precision floating
point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32),
single-precision floating point (FP32), double-precision floating
point (FP64) types, integer data types (4b and 8b), and binary data types (1b).
Furthermore, CUTLASS demonstrates warp-synchronous matrix multiply operations for

Furthermore, CUTLASS demonstrates warp-synchronous matrix multiply operations
targeting the programmable, high-throughput _Tensor Cores_ implemented by
NVIDIA's Volta and Turing architectures.
NVIDIA's Volta, Turing, and Ampere architectures.

See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly.

See the [functionality listing](media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.

# What's New in CUTLASS 2.2

CUTLASS 2.2 is a significant update to CUTLASS adding:

- Coverage of [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/)
- Tensor Core-accelerated GEMMs targeting Tensor Float 32, BFloat16, and double-precision data types
- Deep software pipelines using asynchronous copy
- Intended to be compiled with [CUDA 11 Toolkit](https://developer.nvidia.com/cuda-toolkit)

# What's New in CUTLASS 2.1

CUTLASS 2.1 is a minor update to CUTLASS 2.0 adding:

- [Planar complex GEMM kernels](/examples/10_planar_complex/planar_complex.cu) targeting Volta and Turing Tensor Cores
- BLAS-style API to launch kernels compiled into the [CUTLASS Library](/media/docs/quickstart.md#cutlass-library)


# What's New in CUTLASS 2.0

CUTLASS 2.0 is a substantial refactoring from the previous version, intended to offer:
Expand All @@ -43,25 +56,22 @@ CUTLASS 2.0 is a substantial refactoring from the previous version, intended to

**See the [CHANGELOG](CHANGELOG.md) for more details.**

See the [functionality listing](media/docs/functionality.md) for the list of operations
supported at each level of the execution model hierarchy.

# Performance

<p align="center"><img src=/media/images/cutlass-performance-plot.png></p>

CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels,
they exhibit performance comparable to cuBLAS for scalar GEMM
computations. The above figure shows CUTLASS performance relative to cuBLAS
for large matrix dimensions on an NVIDIA GeForce 2080 Ti and an NVIDIA TitanV
using CUDA 10.2. Tensor Core operations are implemented using CUDA's
for large matrix dimensions on an NVIDIA GeForce 2080 Ti, an NVIDIA A100, and an NVIDIA TitanV
using CUDA 11.0 Toolkit. Tensor Core operations are implemented using CUDA's
[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma).

# Compatibility

CUTLASS requires a C++11 host compiler and
performs best when built with the [CUDA 10.2 Toolkit](https://developer.nvidia.com/cuda-toolkit).
It is compatible with CUDA 9.2, CUDA 10.0, and CUDA 10.1.
performs best when built with the [CUDA 11.0 Toolkit](https://developer.nvidia.com/cuda-toolkit).
It is compatible with CUDA 9.2, CUDA 10.0, CUDA 10.1, and CUDA 10.2.

We have tested the following environments.

Expand All @@ -70,27 +80,28 @@ We have tested the following environments.
| Windows 10 | Microsoft Visual Studio 2015|
| | Microsoft Visual Studio 2017|
| Ubuntu 16.04 | GCC 5.4.0 |
| Ubuntu 18.04 | GCC 7.3.0 |
| Ubuntu 18.04 | GCC 7.5.0 |

Additionally, CUTLASS may be built with clang.
See [these instructions](media/docs/quickstart.md#clang) for more details.

CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on
any Maxwell-, Pascal-, Volta-, or Turing- architecture NVIDIA GPU.

|**GPU**|**Minimum CUDA Toolkit**|**CUDA Toolkit Enabling Native Tensor Cores**|
|---|---|---|
|NVIDIA GeForce 1080|9.2| |
|NVIDIA TitanXP|9.2| |
|NVIDIA Tesla P100|9.2| |
|NVIDIA Tesla V100|9.2|10.1|
|NVIDIA TitanV|9.2|10.1|
|NVIDIA GeForce RTX 2080 TI, 2080, 2070|10.0|10.2|
|NVIDIA Tesla T4|10.0|10.2|
any Maxwell-, Pascal-, Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU.

|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**CUDA Toolkit Enabling Native Tensor Cores**|
|---|---|---|---|
|NVIDIA Tesla P100|6.0|9.2| |
|NVIDIA GeForce 1080|6.1|9.2| |
|NVIDIA TitanXP|6.1|9.2| |
|NVIDIA Tesla V100|7.0|9.2|10.1|
|NVIDIA TitanV|7.0|9.2|10.1|
|NVIDIA GeForce RTX 2080 TI, 2080, 2070|7.5|10.0|10.2|
|NVIDIA Tesla T4|7.5|10.0|10.2|
|NVIDIA A100|8.0|11.0|11.0|

# Documentation

CUTLASS 2.1 is described in the following documents and the accompanying
CUTLASS 2.2 is described in the following documents and the accompanying
[Doxygen documentation](https://nvidia.github.io/cutlass).

- [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS
Expand Down Expand Up @@ -124,7 +135,7 @@ $ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc
```

Create a build directory within the CUTLASS project, then run CMake. By default CUTLASS will build kernels
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0 and 7.5. To reduce compile time you can specify
for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, and 8.0. To reduce compile time you can specify
the architectures to build CUTLASS for by changing the CMake configuration setting
`CUTLASS_NVCC_ARCHS`.

Expand Down Expand Up @@ -210,6 +221,10 @@ examples/
10_planar_complex/ # example demonstrating planar complex GEMM kernels
11_planar_complex_array/ # example demonstrating planar complex kernels with batch-specific problem sizes
12_gemm_bias_relu/ # example demonstrating GEMM fused with bias and relu
13_fused_two_gemms/ # example demonstrating two GEMms fused in one kernel
```

### Tools
Expand Down Expand Up @@ -255,29 +270,32 @@ $ make cutlass_profiler -j

Example command line for profiling SGEMM kernels is as follows:
```
$ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=4352 --n=4096 --k=4096
$ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=3456 --n=4096 --k=4096
=============================
Problem ID: 1
Provider: CUTLASS
Operation: cutlass_simt_sgemm_128x128_nn
Provider: CUTLASS
OperationKind: gemm
Operation: cutlass_simt_sgemm_128x128_8x2_nn_align1
Status: Success
Verification: ON
Disposition: Passed
Disposition: Passed
Status: Success
cuBLAS: Passed
Arguments: --m=4352 --n=4096 --k=4096 --A=f32:column --B=f32:column --C=f32:column --alpha=1 --beta=0 \
--split_k_slices=1 --batch_count=1 --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 \
--stages=2 --warps_m=2 --warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 \
--max_cc=1024
Arguments: --m=3456 --n=4096 --k=4096 --A=f32:column --B=f32:column --C=f32:column --alpha=1 --beta=0 --split_k_slices=1 \
--batch_count=1 --op_class=simt --accum=f32 --cta_m=128 --cta_n=128 --cta_k=8 --stages=2 --warps_m=4 \
--warps_n=2 --warps_k=1 --inst_m=1 --inst_n=1 --inst_k=1 --min_cc=50 --max_cc=1024
Bytes: 52428800 bytes
FLOPs: 146064539648 flops
Bytes: 180355072 bytes
FLOPs: 115992428544 flops
Runtime: 10.5424 ms
Memory: 4.63158 GiB/s
Runtime: 6.73655 ms
Memory: 24.934 GiB/s
Math: 13854.9 GFLOP/s
Math: 17218.4 GFLOP/s
```

[Further details about the CUTLASS Profiler are described here.](media/docs/profiler.md)
Expand Down
2 changes: 1 addition & 1 deletion cmake/nop.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without modification, are permitted
* provided that the following conditions are met:
Expand Down
Loading

0 comments on commit 86931fe

Please sign in to comment.