Skip to content

Commit

Permalink
CUTLASS 1.2
Browse files Browse the repository at this point in the history
  • Loading branch information
kerrmudgeon committed Oct 26, 2018
1 parent 2332df4 commit 74df033
Show file tree
Hide file tree
Showing 97 changed files with 11,306 additions and 637 deletions.
12 changes: 8 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# NVIDIA CUTLASS Changelog

## [1.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.2.0) (2018-10-26)
* Parallelized reductions across threadblocks ("Split-K")
* Improved IGEMM performance
* Batched strided WMMA GEMMs

## 1.1.0 (2018-09-19)
## [1.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v1.1.0) (2018-09-19)
* Turing Features
* WMMA GEMM targeting TensorCores - INT8, INT4, 1-bit
* WMMA GEMM targeting TensorCores - INT8, INT4, INT1
* Batched Strided GEMM
* Threadblock rasterization strategies
* Improved performance for adverse problem sizes and data layouts
Expand All @@ -16,13 +20,13 @@
* Examples
* Basic GEMM, tensor views, CUTLASS utilities, batched GEMM, WMMA GEMM

## 1.0.1 (2018-06-11)
## [1.0.1](https://github.com/NVIDIA/cutlass/releases/tag/v1.0.1) (2018-06-11)

* Intra-threadblock reduction added for small threadblock tile sizes
* sgemm_64x128x16, sgemm_128x128x16, sgemm_128x64x16, sgemm_128x32x16, sgemm_64x64x16, sgemm_64x32x16
* igemm_32x32x128
* GEMM _K_ residue handled during prologue prior to mainloop
* Replaced Google Test copy with submodule. Use `git submodule init`
* Replaced Google Test copy with submodule. Use `git submodule init --recursive --update`

## [1.0.0](https://github.com/NVIDIA/cutlass/commit/2028ebe120aab22bfd0b2baf8902d4c9627eb33f) (2018-05-16)

Expand Down
9 changes: 9 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ else()
string(APPEND NVCC_FLAGS " -lineinfo")
endif()

if (UNIX)
string(APPEND NVCC_FLAGS " -Xcompiler -Wconversion")
endif()

string(APPEND NVCC_FLAGS_DEBUG " -g")
string(APPEND NVCC_FLAGS_RELWITHDEBINFO " -O3")
string(APPEND NVCC_FLAGS_RELEASE " -O3")
Expand Down Expand Up @@ -169,6 +173,8 @@ file(GLOB CUTLASS_GEMM RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/gemm/*.h)
file(GLOB CUTLASS_UTIL RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/util/*.h)
file(GLOB CUTLASS_DEVICE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/device/*.h)
file(GLOB CUTLASS_CORE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/*.h)
file(GLOB CUTLASS_REDUCTION RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/reduction/*.h )

###################################################################################################
#
# Define build targets
Expand All @@ -178,6 +184,7 @@ file(GLOB CUTLASS_CORE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/*.h)
source_group("cutlass\\gemm" FILES ${CUTLASS_GEMM})
source_group("cutlass\\util" FILES ${CUTLASS_UTIL})
source_group("cutlass\\device" FILES ${CUTLASS_DEVICE})
source_group("cutlass\\reduction" FILES ${CUTLASS_REDUCTION})
source_group("cutlass" FILES ${CUTLASS_CORE})

add_library(CUTLASS INTERFACE)
Expand All @@ -187,6 +194,7 @@ target_sources(CUTLASS INTERFACE
${CUTLASS_UTIL}
${CUTLASS_DEVICE}
${CUTLASS_CORE}
${CUTLASS_REDUCTION}
)

target_include_directories(CUTLASS INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
Expand All @@ -197,6 +205,7 @@ add_custom_target(cutlass_ide SOURCES
${CUTLASS_UTIL}
${CUTLASS_DEVICE}
${CUTLASS_CORE}
${CUTLASS_REDUCTION}
)
# Doxygen is available. Generate documentation
if (DOXYGEN_FOUND)
Expand Down
64 changes: 52 additions & 12 deletions CUTLASS.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ CUTLASS core components, and to identify their role in implementing GEMM computa
2. [General Matrix Multiply](#S-general-matrix-multiply)
3. [Core Components](#S-core-components)
4. [Utilities](#S-utilities)
5. [Optimization Strategies](#S-optimization-strategies)

# <a name="S-design-patterns"></a> 1. Design Patterns

Expand All @@ -26,7 +27,7 @@ objectives. This section is intended to provide more detail.

## <a name="S-patterns-sequencing-nesting"></a> Sequencing and Nesting of Collective Primitives

CUTLASS embodies a design paradigm exemplified by the [CUB library](https://nvlabs.github.io/cub/) for expressing collective operations. Objects expose an interface for a problem that is then decomposed into concurrent subtasks executed by cooperating threadblocks, warps, and threads. For example, a grid-level object may be constructed with base pointers to the start of a GEMM operation, add a threadblock-dependent offset to partition the problem, and then compute a per-threadblock GEMM. This in turn performs some operations as a collection of cooperating threads, while it may partition other parts of the task into warp-level subtasks.
CUTLASS embodies a design paradigm exemplified by the [CUB library](https://nvlabs.github.io/cub/) for expressing collective operations. Objects expose an interface for a problem that is then decomposed into concurrent subtasks executed by cooperating threadblocks, warps, and threads. For example, a grid-level object may be constructed with base pointers to the start of a GEMM operation, add a threadblock-dependent offset to partition the problem, and then compute a per-threadblock GEMM. This in turn performs some operations as a collection of cooperating threads, while it may partition other parts of the task into warp-level subtasks.

## <a name="S-patterns-tiles-iterators"></a> Tiles and Iterators

Expand All @@ -48,7 +49,7 @@ CUTLASS can take advantage of this CUDA grid-invariant property by constructing

The design pattern in CUTLASS is for classes with nontrivial constructors to define `struct Params` as an inner class which contains grid-invariant state. These should define a constructor and an `initialize()` method. The `Params` structure should also include a data member corresponding to each data member in the parent class, so these too can be properly constructed in host code. The parent class should define a constructor which accepts `Params const &` as its first argument.

For example, `cutlass::gemm::Gemm<>` should define `struct cutlass::gemm::Gemm::Params`. The latter should define data members for each data member in `cutlass::gemm::Gemm<>`.
For example, `cutlass::gemm::Gemm<>` should define `struct cutlass::gemm::Gemm::Params`. The latter should define data members for each data member in `cutlass::gemm::Gemm<>`.


## <a name="S-patterns-composable-shared-memory"></a> Composable shared memory allocation
Expand Down Expand Up @@ -94,7 +95,7 @@ multiply operation performed by each iteration of the mainloop is referred to as

The threadblock loads a sequence of tiles from global memory and stores this data to shared memory. The iterative
access and traversal of tiles in global memory are performed by a _TileLoadIterator_, and storing to a circular
buffer in shared memory is performed by a _GlobalLoadIterator_.
buffer in shared memory is performed by a _GlobalLoadIterator_.

**[Global Load Stream](cutlass/gemm/gemm_global_stream.h)** manages loading of the threadblock-scope multiplicands to the GEMM kernel. It owns an iterator into global memory for loading tiles of data, a TensorAllocation in shared memory to hold the resulting tile, and an iterator for writing the tile into this allocation. A transformer exists to optionally transform the data as it is loaded which may of use to perform type conversion or, in the case of int8 GEMM, transpose 4x4 tiles held in registers.

Expand All @@ -109,24 +110,24 @@ The Global Load Stream template contains members defined by the following templa
The threadblock's _OutputTile_ is partitioned among the warps, and each computes a warp-level matrix product.
Data is loaded from shared memory into registers, and math instructions are dispatched to CUDA Cores or Tensor Cores.

[**Shared Load Stream**](cutlass/gemm/gemm_shared_stream.h) manages loading of warp-level multiplicands from shared memory into registers. This owns an iterator for fetching data and the destination fragments for holding the results.
[**Shared Load Stream**](cutlass/gemm/gemm_shared_stream.h) manages loading of warp-level multiplicands from shared memory into registers. This owns an iterator for fetching data and the destination fragments for holding the results.

* [GemmSharedLoadTile{A,B}](cutlass/gemm/gemm_shared_tile.h)

**Matrix Multiply** computes a matrix product operation on data held in registers. Specializations exist for thread-level instructions such as single-precision fused multiply-add as well as warp-level matrix operations targeting TensorCores.
**Matrix Multiply** computes a matrix product operation on data held in registers. Specializations exist for thread-level instructions such as single-precision fused multiply-add as well as warp-level matrix operations targeting TensorCores.

* [WMMA Multiply Add](cutlass/gemm/wmma_gemm_multiply_add.h)

## Thread-level GEMM

SGEMM, IGEMM, HGEMM, and DGEMM are computed by SIMT math instructions issued by thread-level matrix multiply
procedures.
procedures.

* [ThreadMultiplyAdd](cutlass/gemm/thread_multiply_add.h)
* [IGEMM specialization](cutlass/gemm/igemm_multiply_add.h)
* [HGEMM specialization](cutlass/gemm/hgemm_multiply_add.h)

## Epilogue
## Epilogue

The [**epilogue**](cutlass/gemm/gemm_epilogue.h) iteratively selects a subset of accumulator elements held by a warp, writes them to shared memory, and loads them by different threads such that a threadblock-scoped tile store operation will make contiguous, striped accesses to global memory. Thus, the flow of data utilizes the following components:

Expand Down Expand Up @@ -227,7 +228,7 @@ must specify compile-time constant tile sizes.
## <a name="S-core-tile-structure"></a> Tile Structure

Tiled structures express an arrangement of data in memory as well as a logical mapping of concurrent CUDA
threads to the problem space. For example, the CUTLASS GEMM
threads to the problem space. For example, the CUTLASS GEMM

Tiled structures can be defined using the `cutlass::TileTraits<>` concept which defines the following
members. Collectively, these members offer a flexible way to define a 4-D subpartition of an integer
Expand Down Expand Up @@ -286,7 +287,7 @@ the next item in sequence.
<img src="/media/images/cutlass-tile-iteration.png" alt="CUTLASS tile access and traversal" width="50%" />

To offer a generic solution that spans numerous data types and layouts, CUTLASS defines the _TileIterator_ concept.
This concept provides access to a sequence of _tiles_ embedded in a tensor in addressable memory.
This concept provides access to a sequence of _tiles_ embedded in a tensor in addressable memory.

The canonical CUTLASS tile iterator template is defined in [cutlass/tile_iterator.h](cutlass/tile_iterator.h).

Expand All @@ -296,9 +297,9 @@ A fragment is analogous to `std::array<>` in that it is a constant-sized array o

## <a name="S-core-predicate-vector"></a> Predicate Vector

SIMT architectures utilize predicated execution in place of control flow when conditional code sequences are fairly short, on the order of a few machine instructions. While CUDA C++ does not include constructs at the language level for predication, PTX makes this explicit, and compilation to SASS is assumed to aggressively utilize predication. Typical applications are to initialize a sequence of bits used to mask memory operations and use these bits as predicates guarding memory load and store instructions.
SIMT architectures utilize predicated execution in place of control flow when conditional code sequences are fairly short, on the order of a few machine instructions. While CUDA C++ does not include constructs at the language level for predication, PTX makes this explicit, and compilation to SASS is assumed to aggressively utilize predication. Typical applications are to initialize a sequence of bits used to mask memory operations and use these bits as predicates guarding memory load and store instructions.

CUTLASS provides `PredicateVector` defined in [cutlass/predicate_vector.h](cutlass/predicate_vector.h) to manage a statically-sized bit vector, store them into general purpose registers, and efficiently access them in sequence. By storing four predicates per byte in hardware registers, the CUDA compiler is able to issue specialized instructions to achieve very efficient unpacking.
CUTLASS provides `PredicateVector` defined in [cutlass/predicate_vector.h](cutlass/predicate_vector.h) to manage a statically-sized bit vector, store them into general purpose registers, and efficiently access them in sequence. By storing four predicates per byte in hardware registers, the CUDA compiler is able to issue specialized instructions to achieve very efficient unpacking.


# <a name="S-utilities"></a> 4. Utilities
Expand All @@ -310,6 +311,46 @@ framework offering features such as:
* Components for allocating and initializing [host-side and device-side tensors](tools/util/host_tensor.h) usable by CUTLASS
* Reference implementations of [GEMM](tools/util/reference/host/gemm.h) and [element-wise operations](tools/util/reference/host/tensor_elementwise.h)


# <a name="S-optimization-strategies"></a>5. Optimization Strategies

This section describes several strategies taken to increase performance beyond what is achievable with
a basic implementation of the hierarchical GEMM structure.


## Threadblock Rasterization

To maximize reuse of data held in the last level cache, CUTLASS defines several functions to
affect the mapping of threadblocks to logical partitions of the GEMM problem. These map
consecutively launched threadblocks to packed two-dimensional regions of the partitioned GEMM
problem to increase the probability that these will access the same tiles of global memory at
approximately the same time.

Several functions are defined in [cutlass/gemm/threadblock_swizzle.h](cutlass/gemm/threadblock_swizzle.h).


## Parallel Reductions across GEMM _K_

Matrix product computations expose parallelism among _O(MN)_ independent inner product
computations. For sufficiently large problem sizes, a GEMM kernel in CUTLASS may approach
the theoretical maximum computational throughput. For small problems, however, there are
too few threadblocks to efficiently occupy the entire GPU.

As a recourse, parallelizing the reduction performed during the inner product computation
enables more threadblocks to execute concurrently while still taking advantage of the throughput
benefits of large threadblock-level GEMM tiles.

CUTLASS implements parallel reductions across threadblocks by partitioning the GEMM _K_ dimension
and launching an additional set of threadblocks for each partition. Consequently, we refer to
this strategy within CUTLASS as "parallel reduction splitK." The "parallel reduction splitK" in cutlass requires the execution of 2 kernels. The first one is called partitionedK GEMM. The second one is called batched reduction.

The partitionedK GEMM is very similar to one flavor of batched strided GEMM. Instead of requiring users to specify the problem size of each batch, partitionedK GEMM asks for the overall problem size and the number of partition that will be applied along K dimension for operand A and B. For example, parameters of m=128, n=128, k=4096 and partition=16 will result in 16 batched strided GEMMs with each batch of m=128, n=128, k=256. PartitionedK also allows scenario where k is not divisible by partition count. For example, parameters of m=128, n=128, k=4096 and partition=20 will result in 20 batched strided GEMMs with the first 19 batches of m=128, n=128, k=4096/20=204 and the last batch of m=128, n=128, k=220.

The batched reduction kernel will further perform reduction along the K-dimension. Thus, the input of the batched reduction kernel is the output (C) of partitionedK GEMM. An workspace memory is managed by the users to store this intermediate results.

An example of splitK usage can be found [here](examples/06_splitK_gemm/splitK_gemm.cu).


# Copyright

Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
Expand All @@ -335,4 +376,3 @@ Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
```

15 changes: 11 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")

# CUTLASS 1.1
# CUTLASS 1.2

_CUTLASS 1.1.0 - September 2018_
_CUTLASS 1.2.0 - October 2018_

CUTLASS 1.1 is a collection of CUDA C++ template abstractions for implementing
CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
It incorporates strategies for hierarchical decomposition and data movement similar
to those used to implement cuBLAS. CUTLASS decomposes these "moving parts" into
Expand All @@ -22,12 +22,19 @@ point (FP64) types. Furthermore, CUTLASS demonstrates CUDA's WMMA API for targe
the programmable, high-throughput _Tensor Cores_ provided by NVIDIA's Volta architecture
and beyond.

CUTLASS 1.1 is described in the [CUTLASS Documentation](CUTLASS.md) and the accompanying
CUTLASS 1.2 is described in the [CUTLASS Documentation](CUTLASS.md) and the accompanying
[Doxygen documentation](https://nvidia.github.io/cutlass).
We describe the structure of an efficient GEMM in our talk at the
[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).

# What's New in CUTLASS 1.2
_October 2018_
* [Parallelized Reductions](CUTLASS.md#parallel-reductions-across-gemm-k)
* Batched strided WMMA GEMM


# What's New in CUTLASS 1.1
_September 2018_

* [CUTLASS Documentation](CUTLASS.md)
* [Examples](examples/)
Expand Down
50 changes: 50 additions & 0 deletions cutlass/coord.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,56 @@ struct Coord {

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Scalar multiplication
template <typename T, int Rank, typename Index>
CUTLASS_HOST_DEVICE
Coord<Rank, Index> operator*(T s, Coord<Rank, Index> coord) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Rank; ++i) {
coord[i] *= s;
}
return coord;
}

/// Scalar multiplication
template <typename T, int Rank, typename Index>
CUTLASS_HOST_DEVICE
Coord<Rank, Index> operator*(Coord<Rank, Index> coord, T s) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Rank; ++i) {
coord[i] *= s;
}
return coord;
}

/// Scalar division
template <typename T, int Rank, typename Index>
CUTLASS_HOST_DEVICE
Coord<Rank, Index> operator/(T s, Coord<Rank, Index> coord) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Rank; ++i) {
coord[i] = s / coord[i];
}
return coord;
}

/// Scalar division
template <typename T, int Rank, typename Index>
CUTLASS_HOST_DEVICE
Coord<Rank, Index> operator/(Coord<Rank, Index> coord, T s) {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < Rank; ++i) {
coord[i] /= s;
}
return coord;
}

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// Integer-valued make_Coord
//
////////////////////////////////////////////////////////////////////////////////////////////////////

/// Helper to make a 2-element coordinate
CUTLASS_HOST_DEVICE
Coord<1> make_Coord(int _0) {
Expand Down
18 changes: 2 additions & 16 deletions cutlass/cutlass.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
////////////////////////////////////////////////////////////////////////////////////////////////////

#define CUTLASS_MAJOR 1
#define CUTLASS_MINOR 1
#define CUTLASS_MINOR 2
#define CUTLASS_PATCH 0
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)

Expand All @@ -49,21 +49,7 @@

#define CUTLASS_ASSERT(x) assert(x)

// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler.
#if defined(__CUDA_ARCH__)
#if defined(_MSC_VER)
#define CUTLASS_PRAGMA_UNROLL __pragma("unroll")
#define CUTLASS_PRAGMA_NO_UNROLL __pragma("unroll 1")
#else
#define CUTLASS_PRAGMA_UNROLL _Pragma("unroll")
#define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1")
#endif
#else
#define CUTLASS_PRAGMA_UNROLL
#define CUTLASS_PRAGMA_NO_UNROLL
#endif

#define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL
#include "cutlass/util/performance_tuning.h"

// A small helper class to dump a type at compile time
// Usage:: DumpType<Class>::Class
Expand Down
Loading

0 comments on commit 74df033

Please sign in to comment.