diff --git a/changelog.md b/CHANGELOG.md
similarity index 77%
rename from changelog.md
rename to CHANGELOG.md
index d9ff1d5dd5..c0606491ea 100644
--- a/changelog.md
+++ b/CHANGELOG.md
@@ -1,6 +1,22 @@
# NVIDIA CUTLASS Changelog
-## [1.0.1](https://github.com/NVIDIA/cutlass/releases/tag/v1.0.1) (2018-06-11)
+
+## 1.1.0 (2018-09-19)
+ * Turing Features
+ * WMMA GEMM targeting TensorCores - INT8, INT4, 1-bit
+ * Batched Strided GEMM
+ * Threadblock rasterization strategies
+ * Improved performance for adverse problem sizes and data layouts
+ * Extended CUTLASS Core components
+ * Tensor views support arbitrary matrix and tensor layouts
+ * Zip iterators for structuring multiple data streams
+ * Enhanced CUTLASS utilities
+ * Reference code for tensor operations in host and device code
+ * Added HostMatrix<> for simplified matrix creation
+ * Examples
+ * Basic GEMM, tensor views, CUTLASS utilities, batched GEMM, WMMA GEMM
+
+## 1.0.1 (2018-06-11)
* Intra-threadblock reduction added for small threadblock tile sizes
* sgemm_64x128x16, sgemm_128x128x16, sgemm_128x64x16, sgemm_128x32x16, sgemm_64x64x16, sgemm_64x32x16
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 5a53fae555..fdd51ae88e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -55,11 +55,21 @@ endif()
find_package(CUDA)
find_package(Doxygen QUIET)
+###################################################################################################
+#
+# Configure CMake variables
+#
+###################################################################################################
+
+find_library(CUBLAS_LIBRARY cublas HINTS
+ ${CUDA_TOOLKIT_ROOT_DIR}/lib64
+ ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64)
+
# By default we want to build in Release mode to ensure that we're getting best performance
if (NOT (CMAKE_BUILD_TYPE OR CONFIGURATION_TYPES))
set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose build level" FORCE)
# We do support Debug or Release builds
- set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release")
+ set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "RelWithDebInfo" "Release")
endif()
if(WIN32)
@@ -68,27 +78,59 @@ if(WIN32)
endif()
if (WIN32)
- # Enable more warnings and treat as errors
- string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX")
+ # Enable more warnings and treat as errors
+ string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX")
- # Disable excess x86 floating point precision that can lead to results being labeled incorrectly
- string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict")
+ # Disable warning on Unicode characters
+ string(APPEND NVCC_FLAGS " -Xcompiler /wd4819")
- # Verbose option
- if (${CUTLASS_NVCC_VERBOSE})
- string(APPEND NVCC_FLAGS " -v")
- endif()
+ # Disable excess x86 floating point precision that can lead to results being labeled incorrectly
+ string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict")
+
+ # Verbose option
+ if (${CUTLASS_NVCC_VERBOSE})
+ string(APPEND NVCC_FLAGS " -v")
+ endif()
endif(WIN32)
-# Configure CUDA options
-set(CUTLASS_NVCC_ARCHS "50;60;61;70" CACHE STRING "The SM architectures to build code for.")
-set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
+set(CUTLASS_NVCC_ARCHS "50;60;61;70;75" CACHE STRING "The SM architectures to build code for.")
+set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.")
+set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.")
+set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.")
+
+#
+# NOTE: running with asan and CUDA requires the following environment variable:
+#
+# ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0
+#
+# without the above environment setting, an error like the following may be generated:
+#
+# *** Error: Could not detect active GPU device ID [out of memory]
+# ...
+# ==9149==ERROR: LeakSanitizer: detected memory leaks
+# ...
+#
+if(ENABLE_ASAN) # https://github.com/google/sanitizers/wiki/AddressSanitizer
+ string(APPEND NVCC_FLAGS " --compiler-options -fsanitize=address --compiler-options -fno-omit-frame-pointer")
+ string(APPEND CMAKE_EXE_LINKER_FLAGS " -fsanitize=address")
+endif()
+###################################################################################################
+#
+# Configure CUDA build options
+#
+###################################################################################################
+
+# Set NVCC arguments
foreach(ARCH ${CUTLASS_NVCC_ARCHS})
- string(APPEND NVCC_FLAGS " -gencode arch=compute_${ARCH},code=sm_${ARCH}")
+ if(CUTLASS_NVCC_EMBED_CUBIN)
+ string(APPEND NVCC_FLAGS " -gencode arch=compute_${ARCH},code=sm_${ARCH}")
+ endif()
+ if(CUTLASS_NVCC_EMBED_PTX)
+ string(APPEND NVCC_FLAGS " -gencode arch=compute_${ARCH},code=compute_${ARCH}")
+ endif()
endforeach()
-
if (CUTLASS_NVCC_KEEP)
string(APPEND NVCC_FLAGS " -keep")
endif()
@@ -99,11 +141,8 @@ else()
string(APPEND NVCC_FLAGS " -lineinfo")
endif()
-if (UNIX)
- string(APPEND NVCC_FLAGS " -Xcompiler -Wconversion")
-endif()
-
string(APPEND NVCC_FLAGS_DEBUG " -g")
+string(APPEND NVCC_FLAGS_RELWITHDEBINFO " -O3")
string(APPEND NVCC_FLAGS_RELEASE " -O3")
# define NDEBUG for release mode to disable assertions
@@ -111,11 +150,13 @@ string(APPEND NVCC_FLAGS_RELEASE " -DNDEBUG")
if (CUTLASS_NATIVE_CUDA)
set(CMAKE_CUDA_FLAGS "${NVCC_FLAGS}")
- set(CMAKE_CUDA_FLAGS_DEBUG "${NVCC_FLAGS_DEBUG}")
set(CMAKE_CUDA_FLAGS_RELEASE "${NVCC_FLAGS_RELEASE}")
+ set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${NVCC_FLAGS_RELWITHDEBINFO}")
+ set(CMAKE_CUDA_FLAGS_DEBUG "${NVCC_FLAGS_DEBUG}")
else()
set(CUDA_NVCC_FLAGS ${NVCC_FLAGS})
set(CUDA_NVCC_FLAGS_DEBUG ${NVCC_FLAGS_DEBUG})
+ set(CUDA_NVCC_FLAGS_RELWITHDEBINFO ${NVCC_FLAGS_RELWITHDEBINFO})
set(CUDA_NVCC_FLAGS_RELEASE ${NVCC_FLAGS_RELEASE})
endif()
@@ -128,6 +169,11 @@ file(GLOB CUTLASS_GEMM RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/gemm/*.h)
file(GLOB CUTLASS_UTIL RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/util/*.h)
file(GLOB CUTLASS_DEVICE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/device/*.h)
file(GLOB CUTLASS_CORE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/*.h)
+###################################################################################################
+#
+# Define build targets
+#
+###################################################################################################
source_group("cutlass\\gemm" FILES ${CUTLASS_GEMM})
source_group("cutlass\\util" FILES ${CUTLASS_UTIL})
@@ -156,9 +202,9 @@ add_custom_target(cutlass_ide SOURCES
if (DOXYGEN_FOUND)
# DOT is available. Enable graph generation in the documentation
if (DOXYGEN_DOT_EXECUTABLE)
- set(CUTLASS_ENABLE_DOXYGEN_DOT ON CACHE BOOL "Use dot to generate graphs in the doxygen documentation.")
+ set(CUTLASS_ENABLE_DOXYGEN_DOT ON CACHE BOOL "Use dot to generate graphs in the doxygen documentation.")
else()
- set(CUTLASS_ENABLE_DOXYGEN_DOT OFF CACHE BOOL "Use dot to generate graphs in the doxygen documentation." FORCE)
+ set(CUTLASS_ENABLE_DOXYGEN_DOT OFF CACHE BOOL "Use dot to generate graphs in the doxygen documentation." FORCE)
endif()
if (CUTLASS_ENABLE_DOXYGEN_DOT)
@@ -177,6 +223,5 @@ if (DOXYGEN_FOUND)
)
endif()
-
-#add_subdirectory(examples/gemm)
add_subdirectory(tools)
+add_subdirectory(examples)
diff --git a/CUTLASS.md b/CUTLASS.md
new file mode 100644
index 0000000000..7dea0f3729
--- /dev/null
+++ b/CUTLASS.md
@@ -0,0 +1,311 @@
+
+
+# CUTLASS
+
+This document is intended to accompany the CUTLASS source code, to describe the interaction between
+CUTLASS core components, and to identify their role in implementing GEMM computations efficiently in CUDA.
+
+1. [Design Patterns](#S-design-patterns)
+2. [General Matrix Multiply](#S-general-matrix-multiply)
+3. [Core Components](#S-core-components)
+4. [Utilities](#S-utilities)
+
+# 1. Design Patterns
+
+CUTLASS strives to achieve the highest performance possible on NVIDIA GPUs while also offering a
+flexible composition that can be easily applied to solve new problems related to Deep Learning and
+linear algebra. Though we intend to make CUTLASS as simple and straightforward as possible, given
+a tradeoff between simplicity and performance, CUTLASS chooses performance. Consequently, several
+design patterns are necessary to yield a composable structure while also satisfying these performance
+objectives. This section is intended to provide more detail.
+
+* [Sequencing and Nesting](#S-patterns-sequencing-nesting)
+* [Tiles and Iterators](#S-patterns-tiles-iterators)
+* [Host-side Params](#S-patterns-host-side-params)
+* [Composable Shared Memory](#S-patterns-composable-shared-memory)
+
+## Sequencing and Nesting of Collective Primitives
+
+CUTLASS embodies a design paradigm exemplified by the [CUB library](https://nvlabs.github.io/cub/) for expressing collective operations. Objects expose an interface for a problem that is then decomposed into concurrent subtasks executed by cooperating threadblocks, warps, and threads. For example, a grid-level object may be constructed with base pointers to the start of a GEMM operation, add a threadblock-dependent offset to partition the problem, and then compute a per-threadblock GEMM. This in turn performs some operations as a collection of cooperating threads, while it may partition other parts of the task into warp-level subtasks.
+
+## Tiles and Iterators
+
+Efficient dense linear algebra computations emphasize data movement to match the execution of mathematical operators to the flow of data. Consequently, CUTLASS defines a rich set of primitives for partitioning a tile of data among participating threads, warps, and threadblocks. CUTLASS applies the familiar iterator design pattern to provide an abstraction layer to (1.) access these tile objects and (2.) traverse a sequence of objects embedded in a higher level data structure. These subpartitions are typically defined by compile-time constants
+specifying element type, size, and data layout. CUTLASS refers to subpartitions as _tiles_.
+
+_Iterators_ are familiar design patterns in C++ that provide an abstraction for accessing individual
+elements in memory as well as traversing over a collection. GEMM kernels in CUTLASS depend on accessing
+a sequence of tiles from global memory, from shared memory, and in registers. Consequently, _tile iterators_
+are prevalent throughout the CUTLASS implementation.
+
+The canonical CUTLASS tile iterator template is defined in [cutlass/tile_iterator.h](cutlass/tile_iterator.h).
+
+## Host-side Params structure
+
+Several CUTLASS template classes exhibit a pattern in which problem-specific internal state is known at kernel launch time and remains invariant throughout the execution of a kernel. For example, tile iterators compute several offsets based on the strides of the input tensor, which are added to an internal pointer when loading the elements of a tile. These offsets are computed from the tensor stride and never updated; the per-thread internal state consists only of the internal global memory pointer.
+
+CUTLASS can take advantage of this CUDA grid-invariant property by constructing the object in host code and passing a composed parameters structure to the kernel. This confers two benefits: (1.) invariant state is held in constant memory, and (2.) each thread avoids the overhead of computing the initial state.
+
+The design pattern in CUTLASS is for classes with nontrivial constructors to define `struct Params` as an inner class which contains grid-invariant state. These should define a constructor and an `initialize()` method. The `Params` structure should also include a data member corresponding to each data member in the parent class, so these too can be properly constructed in host code. The parent class should define a constructor which accepts `Params const &` as its first argument.
+
+For example, `cutlass::gemm::Gemm<>` should define `struct cutlass::gemm::Gemm::Params`. The latter should define data members for each data member in `cutlass::gemm::Gemm<>`.
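+
+The following sketch illustrates the pattern with a hypothetical `MyOperation` class; the class name, members, and pointer arithmetic are illustrative only and not part of CUTLASS.
+
+```
+// Illustrative sketch of the host-side Params pattern (hypothetical names).
+struct MyOperation {
+
+  /// Grid-invariant state constructed on the host and passed to the kernel
+  struct Params {
+    float const *pointer;   ///< base pointer of an operand in device memory
+    int ldm;                ///< leading dimension used to compute offsets
+
+    /// Host-side initialization from problem-specific arguments
+    int initialize(float const *ptr, int leading_dim) {
+      pointer = ptr;
+      ldm = leading_dim;
+      return 0;
+    }
+  };
+
+  /// Per-thread state derived from the grid-invariant Params
+  float const *thread_pointer;
+
+  /// Device-side constructor accepts Params const & as its first argument
+  __device__ MyOperation(Params const &params)
+      : thread_pointer(params.pointer + threadIdx.x * params.ldm) {}
+};
+```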
+
+
+## Composable shared memory allocation
+
+Shared memory requires explicit effort by the programmer to allocate and de-allocate. CUTLASS follows the paradigm introduced by [CUB](https://nvlabs.github.io/cub/) to define composed structures for storing data intended to be held in shared memory. Any object requiring shared memory storage for itself or its data members should define a child structure called SharedStorage. This holds data needed by the class and also instantiates SharedStorage objects for each data member.
+
+To be consistent, this pattern defines a convention in which classes define their internal shared memory storage requirements. Classes should treat all SharedStorage structures other than their own as opaque. When the lifetimes of child objects are known to be non-overlapping, unions may be used to alias multiple SharedStorage objects to the same shared memory region and reduce the overall shared memory footprint.
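+
+As a sketch, with hypothetical class names and sizes, the convention composes the storage of each member and may alias non-overlapping lifetimes with a union:
+
+```
+// Illustrative sketch of composable shared memory storage (hypothetical names).
+struct LoaderA  { struct SharedStorage { float tile[128 * 8]; }; };
+struct LoaderB  { struct SharedStorage { float tile[8 * 128]; }; };
+struct Epilogue { struct SharedStorage { float exchange[128 * 4]; }; };
+
+struct MyKernel {
+  struct SharedStorage {
+    // The mainloop and epilogue storage may alias one another because their
+    // lifetimes do not overlap.
+    union {
+      struct {
+        LoaderA::SharedStorage loader_A;
+        LoaderB::SharedStorage loader_B;
+      } mainloop;
+      Epilogue::SharedStorage epilogue;
+    };
+  };
+};
+```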
+
+## Loop Unrolling
+
+CUTLASS requires tiles of data to be stored in registers for high-bandwidth access. Simultaneously, high-throughput math instructions
+must be issued concurrently with memory instructions to hide latency with relatively few concurrent threads. These objectives are
+achieved by unrolling loops whose iteration counts are known at compile time.
+
+Consequently, most loops within the CUTLASS GEMM implementation are specified by constant values and template arguments. The CUDA compiler
+is able to unroll the loop bodies, map array elements to registers, and construct an efficient instruction schedule.
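+
+For instance, a loop whose trip count is a template parameter can be fully unrolled, allowing its arrays to be mapped to registers (a minimal sketch, not CUTLASS code):
+
+```
+// Illustrative sketch: a compile-time trip count permits full unrolling and
+// register allocation of the arrays.
+template <int kElements>
+__device__ void axpy(float a, float const (&x)[kElements], float (&y)[kElements]) {
+  #pragma unroll
+  for (int i = 0; i < kElements; ++i) {
+    y[i] += a * x[i];
+  }
+}
+```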
+
+## Templates
+
+CUDA C++ templates and modern generic programming techniques enable CUTLASS device code to span a large design space.
+
+This design space includes:
+* Mixed precision arithmetic and data storage
+* Kernels specialized for layout and problem size
+* Support for kernel fusion
+
+Moreover, templates provide a structured approach to collecting compile-time constants such as tile dimensions. These
+must be template arguments to target static array allocation and take advantage of loop unrolling, constant folding,
+and function inlining.
+
+# 2. General Matrix Multiply
+
+The following figure illustrates the hierarchical GEMM computation embodied by CUTLASS. Each stage depicts a nested level of tiling which corresponds to a layer of concurrency within the CUDA execution model and to a level within the memory hierarchy, becoming increasingly finer moving left to right.
+
+
+
+## Threadblock-level GEMM
+
+The CUTLASS GEMM kernel partitions the _C_ matrix into a 2D tiling of threadblocks.
+Each threadblock computes a matrix product whose outer dimensions _M_ and _N_ are compile-time constants. The
+GEMM's _K_ dimension is partitioned into tiles and iterated over by the GEMM _mainloop_. The shape of the matrix
+multiply operation performed by each iteration of the mainloop is referred to as _OutputTile_.
+
+The threadblock loads a sequence of tiles from global memory and stores this data to shared memory. The iterative
+access and traversal of tiles in global memory are performed by a _TileLoadIterator_, and storing to a circular
+buffer in shared memory is performed by a _TileStoreIterator_.
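+
+Conceptually, the mainloop has the following structure (a simplified sketch; the stream and warp-level types are placeholders for the CUTLASS components described below, and `kTileK` stands for the _K_ extent of _OutputTile_):
+
+```
+// Simplified sketch of the threadblock-scoped GEMM mainloop.
+template <typename LoadStream, typename WarpGemm, typename Accumulators>
+__device__ void gemm_mainloop(LoadStream &load_stream,
+                              WarpGemm &warp_gemm,
+                              Accumulators &accumulators,
+                              int K,
+                              int kTileK) {
+
+  for (int k = 0; k < K; k += kTileK) {
+
+    // Load the next tiles of A and B from global memory and store them into
+    // the circular buffer held in shared memory.
+    load_stream.load_and_store_next_tile();
+
+    __syncthreads();
+
+    // Each warp loads its multiplicands from shared memory and issues math
+    // instructions to accumulate a warp-level matrix product.
+    warp_gemm(accumulators);
+
+    __syncthreads();
+  }
+}
+```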
+
+**[Global Load Stream](cutlass/gemm/gemm_global_stream.h)** manages loading of the threadblock-scope multiplicands to the GEMM kernel. It owns an iterator into global memory for loading tiles of data, a TensorAllocation in shared memory to hold the resulting tile, and an iterator for writing the tile into this allocation. A transformer exists to optionally transform the data as it is loaded, which may be of use to perform type conversion or, in the case of int8 GEMM, to transpose 4x4 tiles held in registers.
+
+The Global Load Stream template contains members defined by the following templates:
+
+* [GemmGlobalIteratorAb](cutlass/gemm/gemm_global_tile.h)
+* [Transformer](cutlass/convert.h)
+* [GemmSharedStoreTileAb](cutlass/gemm/gemm_shared_tile.h)
+
+## Warp-level GEMM
+
+The threadblock's _OutputTile_ is partitioned among the warps, and each computes a warp-level matrix product.
+Data is loaded from shared memory into registers, and math instructions are dispatched to CUDA Cores or Tensor Cores.
+
+[**Shared Load Stream**](cutlass/gemm/gemm_shared_stream.h) manages loading of warp-level multiplicands from shared memory into registers. This owns an iterator for fetching data and the destination fragments for holding the results.
+
+* [GemmSharedLoadTile{A,B}](cutlass/gemm/gemm_shared_tile.h)
+
+**Matrix Multiply** computes a matrix product operation on data held in registers. Specializations exist for thread-level instructions such as single-precision fused multiply-add as well as warp-level matrix operations targeting TensorCores.
+
+* [WMMA Multiply Add](cutlass/gemm/wmma_gemm_multiply_add.h)
+
+## Thread-level GEMM
+
+SGEMM, IGEMM, HGEMM, and DGEMM are computed by SIMT math instructions issued by thread-level matrix multiply
+procedures.
+
+* [ThreadMultiplyAdd](cutlass/gemm/thread_multiply_add.h)
+* [IGEMM specialization](cutlass/gemm/igemm_multiply_add.h)
+* [HGEMM specialization](cutlass/gemm/hgemm_multiply_add.h)
+
+## Epilogue
+
+The [**epilogue**](cutlass/gemm/gemm_epilogue.h) iteratively selects a subset of accumulator elements held by a warp, writes them to shared memory, and loads them back with a different thread mapping such that a threadblock-scoped tile store operation makes contiguous, striped accesses to global memory. Thus, the flow of data utilizes the following components:
+
+1. [Transformer](cutlass/convert.h) for converting the data types of accumulator elements
+2. [GemmSharedStoreTileD](cutlass/gemm/gemm_shared_tile.h) to store to shared memory specialized to the accumulator layout.
+3. [GemmSharedLoadTileD](cutlass/gemm/gemm_shared_tile.h) to load the data from shared memory.
+4. [GemmGlobalIteratorC](cutlass/gemm/gemm_global_tile.h) to load a tile from global memory.
+5. A [functor](cutlass/gemm/linear_scaling.h) to compute an element-wise operation on the matrix product and source data, such as `alpha * AB + beta * C` (see the sketch following this list).
+6. [GemmGlobalIteratorD](cutlass/gemm/gemm_global_tile.h) to write the output to global memory.
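+
+A minimal sketch of such an element-wise functor (illustrative only, not the actual `LinearScaling` template):
+
+```
+// Illustrative sketch of the epilogue functor: d = alpha * accum + beta * c.
+struct EpilogueFunctor {
+  float alpha;
+  float beta;
+
+  __device__ float operator()(float accum, float c) const {
+    return alpha * accum + beta * c;
+  }
+};
+```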
+
+## GEMM Traits
+
+[**cutlass::gemm::GemmTraits**](cutlass/gemm/gemm_traits.h) collects the structural properties of a complete GEMM computation into a single template class. As a result, the Traits classes encapsulate the iterators and transformers for all supported GEMM operands and layouts. Low-level details needed by Traits (such as scalar types for operands, thread-block tile size, number of scalar elements per memory access within each phase, number of stages in shared memory, as well as other implementation-specific properties of the GEMM computation) are specified in class [**cutlass::gemm::GemmConfig**](cutlass/gemm/gemm_config.h).
+
+
+# 3. Core Components
+
+CUTLASS GEMM kernels are implemented by a set of Core components for interacting with mathematical tensor and matrix
+objects as well as constructing efficient CUDA kernels.
+
+* [Tensor views](#S-core-tensor-views)
+* [Shape](#S-core-shape)
+* [Tile structure](#S-core-tile-structure)
+* [Fragment](#S-core-fragment)
+* [Predicate vector](#S-core-predicate-vector)
+
+## Tensor View
+
+Matrices and tensors are typically represented as n-D arrays held in linear memory with a single base pointer and a stride vector. Element _i_ of the stride vector indicates the offset in linear memory between consecutive elements in dimension i. Consequently, the linear offset for an arbitrary element specified as an n-tuple may be computed as the dot product of the coordinate and the stride vector.
+
+CUTLASS provides abstractions for interacting with multidimensional tensors in device memory.
+Consequently, we define a hierarchy of pointer-like types for referencing tensors.
+
+`T *` - raw pointer to elements of type T
+
+`cutlass::TensorRef` - reference to a tensor of elements of type T and given rank. Includes a mapping function and associated stride vector for accessing elements in linear memory.
+
+`cutlass::TensorView` - extends `TensorRef<>` by adding bounds information. This is a complete mathematical object which may be used as the argument to CUTLASS functions.
+
+The above provide an identity mapping of a logical index space to linear memory. An element
+at logical coordinate X has an offset computed as follows:
+```
+offset = dot(X, stride)
+```
+where `dot()` computes the inner product of X and a vector of "strides."
+
+CUTLASS 1.1 introduces a mapping function and an additional "storage rank" to offer a flexible way to
+map the logical index space of the tensor to memory. The mapping function maps a coordinate
+of rank _R_ to an index space of rank _S_. The linear offset is computed as:
+```
+offset = dot( MapFunc(X), stride )
+```
+where stride is a vector of rank _S_.
+
+CUTLASS kernels make extensive use of vectorization of memory accesses for efficiency and
+correctness. Consequently, we enforce a constraint on the strides used by mapping functions
+such that:
+
+1. The "fastest-changing" stride is always 1 thereby mandating that consecutive elements in
+ that rank are consecutive in linear memory.
+
+2. The fastest changing rank is always last in the stride vector and not explicitly stored.
+
+Thus, the stride vector used by mapping functions has one fewer element than the rank of the
+storage tensor. These constraints are consistent with the BLAS convention of passing matrices as
+a tuple consisting of a pointer and a "leading dimension." In fact, such matrices are rank=2
+tensors whose fastest-changing stride is 1, and only the strided dimension is explicitly represented.
+
+A typical mapping function might simply map the rows and columns of a matrix, a rank=2 tensor,
+to linear memory such that (1.) elements in the same column are consecutive in memory
+(column-major), or (2.) elements in the same row are consecutive (row-major). These can be
+accomplished by two different mapping functions whose stride vector is length=2. The first
+element is the "leading dimension."
+
+The requirement that the fastest-changing stride always be of unit size need not be a limitation.
+To implement "sparse" computations or matrix operations in which matrix elements have arbitrary
+stride along both row and column, define a mapping function whose storage rank is 3. This permits
+two elements of the stride vector to have a non-unit value.
+
+`cutlass::TensorView<>` extends this concept by including a size vector to specify the bounds of
+the index space. The value of each coordinate in the size vector defines the half-open range of
+indices whose smallest value is zero.
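+
+As a sketch, a column-major mapping for a rank=2 matrix might look like the following (illustrative code, not the CUTLASS implementation):
+
+```
+// Illustrative sketch of a column-major mapping function for a rank=2 matrix.
+// The stride vector holds only the leading dimension; the fastest-changing
+// stride is implicitly 1.
+struct ColumnMajorMapping {
+  int ld;   ///< leading dimension: distance in elements between consecutive columns
+
+  /// Computes offset = column * ld + row * 1
+  int offset(int row, int column) const {
+    return column * ld + row;
+  }
+};
+```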
+
+## Shape
+
+To avoid complicated template metaprogramming, CUTLASS targets fixed compile-time tile sizes specified
+by a four-dimensional template `cutlass::Shape<>`. This defines the following dimensions, mirroring
+the NHWC tensor format used for convolution in Deep Learning frameworks.
+
+- `D`: depth of tensor
+- `H`: first strided dimension
+- `W`: contiguous sequence of tensor elements
+- `C`: number of channels, usually used for vectorized access
+
+Template specializations of `Shape` appear as arguments to numerous dependent template classes which
+must specify compile-time constant tile sizes.
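+
+For example, a tile might be described by the following compile-time shape (a sketch; the specific extents are illustrative):
+
+```
+#include "cutlass/shape.h"
+
+// Illustrative sketch: a 4-D compile-time shape and its constant extents.
+typedef cutlass::Shape<8, 128, 128, 1> ThreadblockTile;
+
+static int const kDepth      = ThreadblockTile::kD;   // 8
+static int const kStrided    = ThreadblockTile::kH;   // 128
+static int const kContiguous = ThreadblockTile::kW;   // 128
+```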
+
+## Tile Structure
+
+Tiled structures express an arrangement of data in memory as well as a logical mapping of concurrent CUDA
+threads to the problem space. For example, the CUTLASS GEMM kernel uses tiled structures to partition the operand and accumulator tiles among threadblocks, warps, and threads.
+
+Tiled structures can be defined using the `cutlass::TileTraits<>` concept which defines the following
+members. Collectively, these members offer a flexible way to define a 4-D subpartition of an integer
+lattice, partition its elements among a collection of threads, and map each unique thread ID to a unique
+offset.
+
+- _Tile_ (concept `Shape<>`) - describes the dimensions of the tile in terms of scalar elements
+- _Delta_ (concept `Shape<>`) - describes the distance along each logical dimension between items
+- _Iterations_ (concept `Shape<>`) - describes the number of items along each logical dimension
+- _ThreadOffset_ (concept _functor_) - implements `Coord<4> operator()() const` to determine a thread's
+ initial offset in the logical 4-D coordinate space
+
+The following figure illustrates the CUTLASS tile structure. The overall shape, 16-by-16, is partitioned into
+vectors of length two among 32 threads. The elements stored by thread 9 are highlighted.
+
+
+
+The `cutlass::TileTraits<>` definition that describes this arrangement may be defined as follows:
+
+```
+struct ExampleTileTraits {
+
+ /// Overall shape of tile
+ typedef Shape<1, 16, 16, 1> Tile;
+
+ /// Distance along each dimension of accesses
+ typedef Shape<1, 4, 1, 1> Delta;
+
+ /// Number of memory accesses performed by each thread
+ typedef Shape<1, 4, 1, 1> Iterations;
+
+ /// Offset function - maps each thread to a unique starting offset within the 4D tile
+ struct ThreadOffset {
+
+ CUTLASS_DEVICE Coord<4> operator()() const {
+
+      typedef Shape<1, 16, 8, 2> Vectorized;
+
+ return make_Coord(
+ 0, // depth "D" dimension
+        threadIdx.x / Vectorized::kW,     // "H" dimension - first strided dimension
+        threadIdx.x % Vectorized::kW,     // "W" dimension - contiguous dimension
+ 0
+ );
+ }
+ };
+};
+```
+
+## Tile Iterator
+
+The iterator design pattern provides an abstraction for accessing the items in a collection in sequence. Basic
+operators defined by iterators consist of accessing an item - either a load or store - followed by traversal to
+the next item in sequence.
+
+
+
+To offer a generic solution that spans numerous data types and layouts, CUTLASS defines the _TileIterator_ concept.
+This concept provides access to a sequence of _tiles_ embedded in a tensor in addressable memory.
+
+The canonical CUTLASS tile iterator template is defined in [cutlass/tile_iterator.h](cutlass/tile_iterator.h).
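+
+In outline, kernels use tile iterators in a load-and-advance pattern similar to the following sketch; the iterator and fragment types are placeholders, not the concrete CUTLASS interface:
+
+```
+// Illustrative sketch of the load-and-advance pattern used with tile iterators.
+// TileIterator and Fragment are placeholders for concrete CUTLASS types.
+template <typename TileIterator, typename Fragment>
+__device__ void consume_tiles(TileIterator tile_iterator, int iterations) {
+
+  Fragment fragment;
+
+  for (int iter = 0; iter < iterations; ++iter) {
+
+    // Access the current tile by loading its elements into a register-backed fragment.
+    tile_iterator.load(fragment);
+
+    // Traverse to the next tile in the sequence.
+    ++tile_iterator;
+
+    // ... consume the fragment (e.g. issue math instructions) ...
+  }
+}
+```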
+
+## Fragment
+
+A fragment is analogous to `std::array<>` in that it is a constant-sized array of elements. Typically backed by storage in the SM's register file, CUTLASS `Fragment<>` objects are used to store tiles. For threadblock- and warp-scope operations, the contents of these tiles are distributed across the participating threads. In such cases, a thread's `Fragment<>` contains the part of the tile held by that thread.
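+
+As a brief sketch, declaring and filling a fragment of eight single-precision elements looks like the following (using the `Fragment<>` interface from [cutlass/fragment.h](cutlass/fragment.h)):
+
+```
+#include "cutlass/fragment.h"
+
+__device__ void example() {
+  // A fragment of eight single-precision elements held in each thread's registers.
+  cutlass::Fragment<float, 8> frag;
+
+  frag.clear();                 // zero-initialize the storage
+  for (int i = 0; i < 8; ++i) {
+    frag[i] = float(i);         // element access via operator[]
+  }
+}
+```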
+
+## Predicate Vector
+
+SIMT architectures utilize predicated execution in place of control flow when conditional code sequences are fairly short, on the order of a few machine instructions. While CUDA C++ does not include constructs at the language level for predication, PTX makes this explicit, and compilation to SASS is assumed to aggressively utilize predication. Typical applications are to initialize a sequence of bits used to mask memory operations and use these bits as predicates guarding memory load and store instructions.
+
+CUTLASS provides `PredicateVector` defined in [cutlass/predicate_vector.h](cutlass/predicate_vector.h) to manage a statically sized vector of bits, store it in general-purpose registers, and efficiently access the bits in sequence. By storing four predicates per byte in hardware registers, the CUDA compiler is able to issue specialized instructions to achieve very efficient unpacking.
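+
+Conceptually, guarding a sequence of loads with precomputed predicates resembles the following sketch (plain CUDA for illustration; a `bool` array stands in for the packed bits of a `PredicateVector`, and the function and parameter names are hypothetical):
+
+```
+// Illustrative sketch: compute guard bits once, then reuse them to predicate
+// each access so out-of-bounds loads are suppressed without divergent branches.
+template <int kAccesses, int kStride>
+__device__ void guarded_load(float *fragment, float const *pointer, int base_offset, int bound) {
+
+  bool guard[kAccesses];
+
+  #pragma unroll
+  for (int i = 0; i < kAccesses; ++i) {
+    guard[i] = (base_offset + i * kStride) < bound;
+  }
+
+  #pragma unroll
+  for (int i = 0; i < kAccesses; ++i) {
+    if (guard[i]) {
+      fragment[i] = pointer[base_offset + i * kStride];
+    }
+  }
+}
+```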
+
+
+# 4. Utilities
+
+CUTLASS implements efficient matrix multiply computations on GPUs. It is accompanied by an extensive utility
+framework offering features such as:
+
+* [cutlass::half_t](tools/util/half.h) - a host-side half-precision type
+* Components for allocating and initializing [host-side and device-side tensors](tools/util/host_tensor.h) usable by CUTLASS
+* Reference implementations of [GEMM](tools/util/reference/host/gemm.h) and [element-wise operations](tools/util/reference/host/tensor_elementwise.h)
diff --git a/Doxyfile b/Doxyfile
index 51cec529b3..1d96f37708 100644
--- a/Doxyfile
+++ b/Doxyfile
@@ -58,7 +58,7 @@ PROJECT_LOGO =
# entered, it will be relative to the location where doxygen was started. If
# left blank the current directory will be used.
-OUTPUT_DIRECTORY = docs
+OUTPUT_DIRECTORY = doxygen
# If the CREATE_SUBDIRS tag is set to YES, then doxygen will create 4096 sub-
# directories (in 2 levels) under the output directory of each output format and
diff --git a/README.md b/README.md
index 56473a2861..c53a42f4bc 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,10 @@

-# CUTLASS 1.0
+# CUTLASS 1.1
-_CUTLASS 1.0.1 - June 2018_
+_CUTLASS 1.1.0 - September 2018_
-CUTLASS 1.0 is a collection of CUDA C++ template abstractions for implementing
+CUTLASS 1.1 is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA.
It incorporates strategies for hierarchical decomposition and data movement similar
to those used to implement cuBLAS. CUTLASS decomposes these "moving parts" into
@@ -22,14 +22,27 @@ point (FP64) types. Furthermore, CUTLASS demonstrates CUDA's WMMA API for targe
the programmable, high-throughput _Tensor Cores_ provided by NVIDIA's Volta architecture
and beyond.
-CUTLASS 1.0 has changed substantially from our preview release described in
-the [CUTLASS Parallel For All](https://devblogs.nvidia.com/parallelforall/cutlass-linear-algebra-cuda)
-post. We have decomposed the structure of the GEMM computation into deeper, structured
-primitives for loading data, computing predicate masks, streaming data at each level of
-the GEMM hierarchy, and updating the output matrix.
-
-CUTLASS 1.0 is described in the [Doxygen documentation](https://nvidia.github.io/cutlass)
-and our talk at the [GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
+CUTLASS 1.1 is described in the [CUTLASS Documentation](CUTLASS.md) and the accompanying
+[Doxygen documentation](https://nvidia.github.io/cutlass).
+We describe the structure of an efficient GEMM in our talk at the
+[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
+
+# What's New in CUTLASS 1.1
+
+* [CUTLASS Documentation](CUTLASS.md)
+* [Examples](examples/)
+ * Basic GEMM, tensor views, CUTLASS utilities, batched GEMM, WMMA GEMM
+* Turing Features
+ * [WMMA GEMM targeting TensorCores](tools/test/unit/gemm/wmma_integer_gemm.cu) - INT8, INT4, 1-bit
+* [Batched Strided GEMM](tools/test/unit/gemm/batched_strided_sgemm_128x128x8.cu)
+* [Threadblock rasterization strategies](tools/test/unit/gemm/sgemm_threadblock_swizzle_nt.cu)
+ * Improved performance for adverse problem sizes and data layouts
+* Extended CUTLASS Core components
+ * Tensor views support arbitrary matrix and tensor layouts
+ * Zip iterators for structuring multiple data streams
+* Enhanced CUTLASS utilities
+ * [Reference implementations](tools/util/reference) for tensor operations in [host](tools/util/reference/host) and [device](tools/util/reference/device) code
+ * Added `HostMatrix<>` for simplified matrix creation
# Performance
@@ -39,11 +52,11 @@ CUTLASS primitives are very efficient. When used to construct device-wide GEMM
they exhibit performance comparable to cuBLAS for scalar GEMM
computations. The above figure shows CUTLASS performance relative to cuBLAS
for large matrix dimensions (M=10240, N=K=4096) running on an NVIDIA Titan V GPU
-when compiled with CUDA 9.2.
+when compiled with CUDA 10.0.
# Compatibility
-CUTLASS requires CUDA 9 and performs best with [CUDA 9.2 Toolkit](ttps://developer.nvidia.com/cuda-toolkit) or later.
+CUTLASS requires CUDA 9 but performs best with [CUDA 10.0 Toolkit](https://developer.nvidia.com/cuda-toolkit) or later.
|**Operating System** | **Compiler** |
|-----------------|----------|
@@ -63,7 +76,7 @@ any Maxwell-, Pascal-, or Volta-architecture NVIDIA GPU.
|NVIDIA Tesla P100|
|NVIDIA Tesla V100|
|NVIDIA TitanV|
-
+|NVIDIA GeForce RTX 2080 TI, 2080, 2070|
# Building CUTLASS
@@ -79,7 +92,7 @@ $ git submodule update --init --recursive
```
CUTLASS can be build with CMake starting version 3.10. By default CUTLASS will build kernels
-for CUDA architecture versions 5.0, 6.0, 6.1 and 7.0. To reduce compile time you can specify
+for CUDA architecture versions 5.0, 6.0, 6.1, 7.0 and 7.5. To reduce compile time you can specify
the architectures to build CUTLASS for by changing the CMake configuration setting
`CUTLASS_NVCC_ARCHS`.
@@ -107,13 +120,12 @@ $ ./tools/test/unit/cutlass_unit_test
...
...
[----------] Global test environment tear-down
-[==========] 481 tests from 24 test cases ran. (5954 ms total)
-[ PASSED ] 481 tests.
+[==========] 946 tests from 57 test cases ran. (10812 ms total)
+[ PASSED ] 946 tests.
```
All tests should pass, though the exact number of tests may vary over time.
-
# Project Structure
CUTLASS is arranged as a header-only library with several example test programs
@@ -128,28 +140,41 @@ templates in the cutlass/gemm directory.
```
cutlass/
- gemm/
- util/
-
+ gemm/
+ util/
+
```
Several tools and test programs are also distributed with the CUTLASS library. They are
contained in the following directories.
```
+examples/
+ 00_basic_gemm/
+ 01_tensor_view/
+ 02_cutlass_utilities/
+ 03_batched_gemm/
+ 04_tile_iterator/
+ 05_wmma_gemm/
tools/
- test/
- unit/
- core/
- gemm/
- perf/
- util/
-
+ test/
+ unit/
+ core/
+ gemm/
+ perf/
+ util/
+ reference/
+ device/
+ host/
+
```
The `test/unit/` directory consist of unit tests implemented with Google Test that demonstrate
basic usage of Core API components and complete tests of the CUTLASS GEMM computations.
+The `tools/util` directory contains CUTLASS utilities including reference implementations of GEMM and
+several element-wise tensor operations.
+
# Performance Profiling
The `test/perf/` directory contains a command-line utility for launching each of the GEMM kernels.
diff --git a/clang-format.sh b/clang-format.sh
deleted file mode 100755
index b2570d9147..0000000000
--- a/clang-format.sh
+++ /dev/null
@@ -1,17 +0,0 @@
-#!/bin/bash
-
-set -e
-
-function formatFiles {
- for f in `find "$1" -type f -name "*.$2"` ; do
- COMMAND="clang-format -i $f"
- echo $COMMAND
- $COMMAND
- done
-}
-
-formatFiles "cutlass" "h"
-formatFiles "tools/test" "h"
-formatFiles "tools/test" "cpp"
-formatFiles "tools/util" "h"
-
diff --git a/cutlass/convert.h b/cutlass/convert.h
index 933d68a82a..b4d0f8eddb 100644
--- a/cutlass/convert.h
+++ b/cutlass/convert.h
@@ -28,7 +28,7 @@
*/
#pragma once
-#include
+#include "cutlass/fragment.h"
namespace cutlass {
diff --git a/cutlass/coord.h b/cutlass/coord.h
index 431c9bf1a0..625a22723d 100644
--- a/cutlass/coord.h
+++ b/cutlass/coord.h
@@ -28,7 +28,8 @@
#pragma once
-#include
+#include "cutlass/cutlass.h"
+#include "cutlass/util/platform.h"
namespace cutlass {
@@ -44,20 +45,27 @@ struct Identity {
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Statically-sized array specifying Coords within a tensor
-template
+template <int Rank_, typename Index_ = int>
struct Coord {
//
// Type and constant definitions
//
- static int const N = N_;
+ /// Number of elements in Coord
+ static int const kRank = Rank_;
+
+ /// Number of elements in Coord, aliased for compatibility
+ static int const N = Rank_;
+
+ /// Index type used to store elements
+ typedef Index_ Index;
//
// Data members
//
/// Indices
- int idx[N];
+ Index idx[kRank];
//
// Methods
@@ -65,25 +73,72 @@ struct Coord {
/// Default ctor initializes uniformly
CUTLASS_HOST_DEVICE
- Coord(int value = 0) {
- for (int i = 0; i < N; ++i) {
+ Coord(Index value = 0) {
+ for (int i = 0; i < kRank; ++i) {
idx[i] = value;
}
}
/// Constructs from an array of integers
CUTLASS_HOST_DEVICE
- Coord(int _idx[]) {
- for (int i = 0; i < N; ++i) {
+ Coord(Index _idx[]) {
+ for (int i = 0; i < kRank; ++i) {
idx[i] = _idx[i];
}
}
+ /// Constructs from an array of integers
+ CUTLASS_HOST_DEVICE
+ Coord(Coord const &coord) {
+ for (int i = 0; i < kRank; ++i) {
+ idx[i] = coord[i];
+ }
+ }
+
+ /// Returns a slice of the Coord which may be larger or smaller in rank
+ /// than this.
+  template <int Slice>
+  CUTLASS_HOST_DEVICE
+  Coord<Slice> slice(int start = 0, Index identity = 0) const {
+    Coord<Slice> result;
+    for (int i = 0; i < Slice; ++i) {
+      if (i + start < kRank) {
+        result[i] = idx[i + start];
+      }
+      else {
+        result[i] = identity;
+ }
+ }
+ return result;
+ }
+
+ /// Returns true if Coord is non-zero.
+ CUTLASS_HOST_DEVICE
+ operator bool() const {
+ for (int i = 0; i < kRank; ++i) {
+ if (idx[i]) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ /// Returns true if Coord is uniformly zero.
+ CUTLASS_HOST_DEVICE
+ bool operator!() const {
+ for (int i = 0; i < kRank; ++i) {
+ if (idx[i]) {
+ return false;
+ }
+ }
+ return true;
+ }
+
/// Element-wise addition
CUTLASS_HOST_DEVICE
Coord operator+(Coord const& b) const {
Coord c;
- for (int i = 0; i < N; ++i) {
+ for (int i = 0; i < kRank; ++i) {
c.idx[i] = idx[i] + b.idx[i];
}
return c;
@@ -93,7 +148,7 @@ struct Coord {
CUTLASS_HOST_DEVICE
Coord operator-(Coord const& b) const {
Coord c;
- for (int i = 0; i < N; ++i) {
+ for (int i = 0; i < kRank; ++i) {
c.idx[i] = idx[i] - b.idx[i];
}
return c;
@@ -103,7 +158,7 @@ struct Coord {
CUTLASS_HOST_DEVICE
Coord operator*(Coord const& b) const {
Coord c;
- for (int i = 0; i < N; ++i) {
+ for (int i = 0; i < kRank; ++i) {
c.idx[i] = idx[i] * b.idx[i];
}
return c;
@@ -113,7 +168,7 @@ struct Coord {
CUTLASS_HOST_DEVICE
Coord operator/(Coord const& b) const {
Coord c;
- for (int i = 0; i < N; ++i) {
+ for (int i = 0; i < kRank; ++i) {
c.idx[i] = idx[i] / b.idx[i];
}
return c;
@@ -122,7 +177,7 @@ struct Coord {
/// In-place addition
CUTLASS_HOST_DEVICE
Coord& operator+=(Coord const& b) {
- for (int i = 0; i < N; ++i) {
+ for (int i = 0; i < kRank; ++i) {
idx[i] += b.idx[i];
}
return *this;
@@ -131,7 +186,7 @@ struct Coord {
/// In-place subtraction
CUTLASS_HOST_DEVICE
Coord& operator-=(Coord const& b) {
- for (int i = 0; i < N; ++i) {
+ for (int i = 0; i < kRank; ++i) {
idx[i] -= b.idx[i];
}
return *this;
@@ -140,7 +195,7 @@ struct Coord {
/// In-place multiplication
CUTLASS_HOST_DEVICE
Coord& operator*=(Coord const& b) {
- for (int i = 0; i < N; ++i) {
+ for (int i = 0; i < kRank; ++i) {
idx[i] *= b.idx[i];
}
return *this;
@@ -149,22 +204,22 @@ struct Coord {
/// In-place division
CUTLASS_HOST_DEVICE
Coord& operator/=(Coord const& b) {
- for (int i = 0; i < N; ++i) {
+ for (int i = 0; i < kRank; ++i) {
idx[i] /= b.idx[i];
}
return *this;
}
/// Member access operator
- CUTLASS_HOST_DEVICE int& operator[](int dim) { return idx[dim]; }
+ CUTLASS_HOST_DEVICE Index& operator[](int dim) { return idx[dim]; }
/// Member access operator
- CUTLASS_HOST_DEVICE int const& operator[](int dim) const { return idx[dim]; }
+ CUTLASS_HOST_DEVICE Index const& operator[](int dim) const { return idx[dim]; }
/// Computes the dot product of two Coord instances
  template <typename T>
CUTLASS_HOST_DEVICE T dot(Coord const& b, T sum) const {
- for (int i = 0; i < N; ++i) {
+ for (int i = 0; i < kRank; ++i) {
sum += idx[i] * b.idx[i];
}
return sum;
@@ -174,7 +229,7 @@ struct Coord {
  template <typename T>
CUTLASS_HOST_DEVICE T dot(Coord const& b) const {
T sum = T(0);
- for (int i = 0; i < N; ++i) {
+ for (int i = 0; i < kRank; ++i) {
sum += idx[i] * b.idx[i];
}
return sum;
@@ -182,29 +237,29 @@ struct Coord {
/// Gets the index of a given Coord element
  template <int Dim>
- CUTLASS_HOST_DEVICE int& at() {
+ CUTLASS_HOST_DEVICE Index& at() {
return idx[Dim];
}
/// Access via index; may limit unrolling potential
CUTLASS_HOST_DEVICE
- int& at(int dim) { return idx[dim]; }
+ Index& at(int dim) { return idx[dim]; }
/// Gets the index of a given Coord element
  template <int Dim>
- CUTLASS_HOST_DEVICE int const& at() const {
+ CUTLASS_HOST_DEVICE Index const& at() const {
return idx[Dim];
}
/// Access via index; may limit unrolling potential
CUTLASS_HOST_DEVICE
- int const& at(int dim) const { return idx[dim]; }
+ Index const& at(int dim) const { return idx[dim]; }
/// Determines if two Coord<> objects are equal
CUTLASS_HOST_DEVICE
- bool operator==(Coord const& b) const {
+ bool operator==(Coord const& b) const {
bool equal = true;
- for (int i = 0; equal && i < N; ++i) {
+ for (int i = 0; equal && i < kRank; ++i) {
equal = (idx[i] == b.idx[i]);
}
return equal;
@@ -212,12 +267,12 @@ struct Coord {
/// Not equal
CUTLASS_HOST_DEVICE
- bool operator!=(Coord const& b) const { return !(*this == b); }
+ bool operator!=(Coord const& b) const { return !(*this == b); }
/// Clamps a coordinate to a range specified by maximum and minimum values
CUTLASS_HOST_DEVICE
- Coord& clamp(Coord const& max, Coord const& min = Coord()) {
- for (int i = 0; i < N; ++i) {
+ Coord& clamp(Coord const& max, Coord const& min = Coord()) {
+ for (int i = 0; i < kRank; ++i) {
idx[i] = __NV_STD_MAX(__NV_STD_MIN(idx[i], max.idx[i]), min.idx[i]);
}
return *this;
@@ -225,13 +280,35 @@ struct Coord {
/// Returns the product of all elements
CUTLASS_HOST_DEVICE
- int count() const {
- int product = idx[0];
- for (int i = 1; i < N; ++i) {
+ Index count() const {
+ Index product = idx[0];
+ for (int i = 1; i < kRank; ++i) {
product *= idx[i];
}
return product;
}
+
+ /// Less than operator
+ CUTLASS_HOST_DEVICE
+ bool operator<(Coord const &b) const {
+ for (int i = 0; i < kRank; ++i) {
+ if (!(idx[i] < b[i])) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /// Less than or equals operator
+ CUTLASS_HOST_DEVICE
+ bool operator<=(Coord const &b) const {
+ for (int i = 0; i < kRank; ++i) {
+ if (!(idx[i] <= b[i])) {
+ return false;
+ }
+ }
+ return true;
+ }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -266,21 +343,10 @@ Coord<4> make_Coord(int _0, int _1, int _2, int _3) {
////////////////////////////////////////////////////////////////////////////////////////////////////
-/// Getter
-CUTLASS_HOST_DEVICE
-Coord<2> get_Coord_hw(Coord<3> const& coord) { return make_Coord(coord[1], coord[2]); }
-
-/// Getter
-CUTLASS_HOST_DEVICE
-Coord<2> get_Coord_hw(Coord<4> const& coord) { return make_Coord(coord[1], coord[2]); }
-
-/// Getter
-CUTLASS_HOST_DEVICE
-Coord<3> get_Coord_hwc(Coord<4> const& coord) { return make_Coord(coord[1], coord[2], coord[3]); }
-
-/// Getter
-CUTLASS_HOST_DEVICE
-Coord<3> get_Coord_dhw(Coord<4> const& coord) { return make_Coord(coord[0], coord[1], coord[2]); }
+template <typename Shape_>
+CUTLASS_HOST_DEVICE Coord<3> make_Coord_from_shape() {
+ return make_Coord(Shape_::kD, Shape_::kH, Shape_::kW);
+}
////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/cutlass/core_io.h b/cutlass/core_io.h
index cceea4c06d..849a7613f4 100644
--- a/cutlass/core_io.h
+++ b/cutlass/core_io.h
@@ -22,8 +22,6 @@
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
-#pragma once
-
/*! \file
\brief Helpers for printing cutlass/core objects
*/
@@ -33,12 +31,96 @@
#include
#include
-#include
+#include "cutlass/coord.h"
+#include "cutlass/vector.h"
+
+namespace cutlass {
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
template <int Rank>
-std::ostream& operator<<(std::ostream& out, cutlass::Coord const& coord) {
+std::ostream& operator<<(std::ostream& out, Coord<Rank> const& coord) {
for (int i = 0; i < Rank; ++i) {
out << (i ? ", " : "") << coord.idx[i];
}
return out;
}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Helper to enable formatted printing of CUTLASS scalar types to an ostream
+template <typename T>
+struct ScalarIO {
+
+ /// Value to print
+ T value;
+
+ /// Default ctor
+ ScalarIO() { }
+
+ /// Constructs from a value
+ ScalarIO(T value): value(value) {}
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Default printing to ostream
+template <typename T>
+inline std::ostream &operator<<(std::ostream &out, ScalarIO<T> const &scalar) {
+ return out << scalar.value;
+}
+
+/// Printing to ostream of int8_t as integer rather than character
+template <>
+inline std::ostream &operator<<(std::ostream &out, ScalarIO<int8_t> const &scalar) {
+ return out << int(scalar.value);
+}
+
+/// Printing to ostream of uint8_t as integer rather than character
+template <>
+inline std::ostream &operator<<(std::ostream &out, ScalarIO<uint8_t> const &scalar) {
+ return out << unsigned(scalar.value);
+}
+
+/// Printing to ostream of vector of 1b elements
+template <>
+inline std::ostream &operator<<(
+ std::ostream &out,
+  ScalarIO<Vector<bin1_t, 32> > const &scalar) {
+
+ for (int i = 0; i < 32; i++) {
+ out << int(scalar.value[i]);
+ out << ((i != 31) ? ", " : "");
+ }
+ return out;
+}
+
+/// Printing to ostream of vector of 4b signed integer elements
+template <>
+inline std::ostream &operator<<(
+ std::ostream &out,
+  ScalarIO<Vector<int4_t, 8> > const &scalar) {
+
+ for (int i = 0; i < 8; i++) {
+ out << int(scalar.value[i]);
+ out << ((i != 7) ? ", " : "");
+ }
+ return out;
+}
+
+/// Printing to ostream of vector of 4b unsigned integer elements
+template <>
+inline std::ostream &operator<<(
+ std::ostream &out,
+  ScalarIO<Vector<uint4_t, 8> > const &scalar) {
+
+ for (int i = 0; i < 8; i++) {
+ out << unsigned(scalar.value[i]);
+ out << ((i != 7) ? ", " : "");
+ }
+ return out;
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace cutlass
diff --git a/cutlass/cutlass.h b/cutlass/cutlass.h
index 19600ec8f7..15ea83c014 100644
--- a/cutlass/cutlass.h
+++ b/cutlass/cutlass.h
@@ -32,8 +32,8 @@
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CUTLASS_MAJOR 1
-#define CUTLASS_MINOR 0
-#define CUTLASS_PATCH 1
+#define CUTLASS_MINOR 1
+#define CUTLASS_PATCH 0
#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
#ifdef __NVCC__
@@ -47,7 +47,9 @@
// CUTLASS_DEVICE is an error if not compiling device code
#endif
-// CUTLASS_PRAGMA_UNROLL inserts a CUTLASS_PRAGMA_UNROLL if supported by the compiler
+#define CUTLASS_ASSERT(x) assert(x)
+
+// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler.
#if defined(__CUDA_ARCH__)
#if defined(_MSC_VER)
#define CUTLASS_PRAGMA_UNROLL __pragma("unroll")
@@ -61,7 +63,22 @@
#define CUTLASS_PRAGMA_NO_UNROLL
#endif
-#define CUTLASS_ASSERT(x) assert(x)
+#define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL
+
+// A small helper class to dump a type at compile time
+// Usage: DebugType<T>::Class
+template <typename T>
+struct DebugType {};
+
+template <typename T>
+void DebugTypeFunc(T const& t) {
+ T::t;
+}
+
+// A small helper class to dump a compile time constant at compile time
+// Usage: DebugValue<Value>::kConstant
+template <int Value>
+struct DebugValue {};
namespace cutlass {
diff --git a/cutlass/fragment.h b/cutlass/fragment.h
index 886b11405c..6a93d779c4 100644
--- a/cutlass/fragment.h
+++ b/cutlass/fragment.h
@@ -29,9 +29,9 @@
#pragma once
#include
-#include
-#include
-#include
+#include "cutlass/shape.h"
+#include "cutlass/util/cutlass_math.h"
+#include "cutlass/vector.h"
namespace cutlass {
@@ -72,7 +72,7 @@ provides access to element at (d, h, w, c)
////////////////////////////////////////////////////////////////////////////////////////////////////
-template
+template
struct StorageType {
typedef uint64_t Type;
};
@@ -108,9 +108,11 @@ struct Fragment : public AlignedStruct {
typedef Element_ Element;
/// The number of elements.
static int const kElements = kElements_;
+ /// Alignment
+ static int const kAlignment = kAlignment_;
/// Clear a fragment.
- CUTLASS_DEVICE void clear() {
+ CUTLASS_HOST_DEVICE void clear() {
// Avoid element-wise access for sub 32b element type
if (kAlignment_ >= 8 && (kElements * sizeof(Element)) % 8 == 0) {
      uint64_t* ptr = reinterpret_cast<uint64_t*>(storage);
@@ -135,14 +137,10 @@ struct Fragment : public AlignedStruct {
}
/// The accessor.
- CUTLASS_DEVICE Element& operator[](int i) {
- assert(i < kElements_);
- return reinterpret_cast(storage)[i];
- }
+  CUTLASS_HOST_DEVICE Element& operator[](int i) { return reinterpret_cast<Element*>(storage)[i]; }
/// The accessor.
- CUTLASS_DEVICE Element const& operator[](int i) const {
- assert(i < kElements_);
+ CUTLASS_HOST_DEVICE Element const& operator[](int i) const {
    return reinterpret_cast<Element const*>(storage)[i];
}
@@ -188,35 +186,35 @@ struct FragmentIterator {
/// Ctor.
  template <typename OtherFragment_>
- CUTLASS_DEVICE FragmentIterator(OtherFragment_& fragment, int offset = 0)
+ CUTLASS_HOST_DEVICE FragmentIterator(OtherFragment_& fragment, int offset = 0)
      : pointer(reinterpret_cast<Element*>(&fragment[offset])) {
static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
}
/// The accessor.
- CUTLASS_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
+ CUTLASS_HOST_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
int const imm = ComputeOffsetFromStrides::get(d, h, w, c);
return reinterpret_cast(pointer[imm]);
}
/// The accessor.
- CUTLASS_DEVICE AccessType& at(int d, int h, int w, int c = 0) {
+ CUTLASS_HOST_DEVICE AccessType& at(int d, int h, int w, int c = 0) {
int const imm = ComputeOffsetFromStrides::get(d, h, w, c);
return reinterpret_cast(pointer[imm]);
}
/// The accessor.
- CUTLASS_DEVICE AccessType const& operator[](int i) const {
+ CUTLASS_HOST_DEVICE AccessType const& operator[](int i) const {
return reinterpret_cast(pointer[i * kElementsPerAccess]);
}
/// The accessor.
- CUTLASS_DEVICE AccessType& operator[](int i) {
+ CUTLASS_HOST_DEVICE AccessType& operator[](int i) {
return reinterpret_cast(pointer[i * kElementsPerAccess]);
}
/// Is the iterator valid?
- CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
+ CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
/// The pointer.
Element* pointer;
@@ -246,28 +244,28 @@ struct FragmentConstIterator {
/// Ctor.
  template <typename OtherFragment_>
- CUTLASS_DEVICE FragmentConstIterator(OtherFragment_& fragment, int offset = 0)
+ CUTLASS_HOST_DEVICE FragmentConstIterator(OtherFragment_& fragment, int offset = 0)
: pointer(reinterpret_cast(&fragment[offset])) {
static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
}
/// Create from non-constant FragmentIterator
- CUTLASS_DEVICE FragmentConstIterator(
+ CUTLASS_HOST_DEVICE FragmentConstIterator(
FragmentIterator const& rhs_)
: pointer(reinterpret_cast(rhs_.offset)) {}
/// The accessor.
- CUTLASS_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
+ CUTLASS_HOST_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
int const imm = ComputeOffsetFromStrides::get(d, h, w, c);
return reinterpret_cast(pointer[imm]);
}
/// The accessor.
- CUTLASS_DEVICE AccessType const& operator[](int i) const {
+ CUTLASS_HOST_DEVICE AccessType const& operator[](int i) const {
return reinterpret_cast(pointer[i * kElementsPerAccess]);
}
/// Is the iterator valid?
- CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
+ CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
/// The pointer.
Element const* pointer;
diff --git a/cutlass/fragment_load_store.h b/cutlass/fragment_load_store.h
deleted file mode 100644
index a7d272e9e3..0000000000
--- a/cutlass/fragment_load_store.h
+++ /dev/null
@@ -1,135 +0,0 @@
-/***************************************************************************************************
- * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
- *
- * Redistribution and use in source and binary forms, with or without modification, are permitted
- * provided that the following conditions are met:
- * * Redistributions of source code must retain the above copyright notice, this list of
- * conditions and the following disclaimer.
- * * Redistributions in binary form must reproduce the above copyright notice, this list of
- * conditions and the following disclaimer in the documentation and/or other materials
- * provided with the distribution.
- * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
- * to endorse or promote products derived from this software without specific prior written
- * permission.
- *
- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
- * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
- * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
- * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
- * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
- * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
- * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
- * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- **************************************************************************************************/
-/*! \file
- \brief Defines accessors for loading and storing fragments to memory efficiently.
-*/
-#pragma once
-
-#include
-#include
-
-namespace cutlass {
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template
-struct FragmentLoad {};
-
-template
-struct FragmentLoad {
- /// The output type.
- typedef FragmentElement_ AccessType;
-
- /// The load function.
- static CUTLASS_DEVICE void load(AccessType& value, Scalar_ const* pointer, int offset) {
- value.load(&pointer[offset], kStride);
- }
-};
-
-template
-struct FragmentLoad {
- /// The output type.
- typedef typename Vectorize::Type AccessType;
-
- /// The load function.
- static CUTLASS_DEVICE void load(AccessType& value, Scalar_ const* pointer, int offset) {
- Load::load(value, pointer, offset);
- }
-};
-
-template
-struct FragmentStore {};
-
-template
-struct FragmentStore {
- /// The input type.
- typedef FragmentElement_ AccessType;
-
- /// The store function.
- static CUTLASS_DEVICE void store(AccessType const& value, Scalar_* pointer, int offset) {
- value.store(&pointer[offset], kStride);
- }
-};
-
-template
-struct FragmentStore {
- /// The input type.
- typedef typename Vectorize::Type AccessType;
-
- /// The store function.
- static CUTLASS_DEVICE void store(AccessType const& value, Scalar_* pointer, int offset) {
- Store::store(value, pointer, offset);
- }
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-} /// namespace cutlass
diff --git a/cutlass/fragment_multiply_add.h b/cutlass/fragment_multiply_add.h
index 36a4d6f6a5..de2c8052fe 100644
--- a/cutlass/fragment_multiply_add.h
+++ b/cutlass/fragment_multiply_add.h
@@ -27,52 +27,59 @@
*/
#pragma once
-#include
+#include "cutlass/fragment.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
-template
+template < typename ScalarAlphaBeta_,
+ typename ScalarAccum_,
+           bool fragMul2 = true /*number of elements per fragment is multiple of 2*/
+>
struct FragmentMultiplyAdd {
/// The shape of the instruction.
typedef Shape<1, 1, 1, 1> InstructionShape;
- /// The type for A.
- typedef Scalar_ ScalarA;
- /// The type for B.
- typedef Scalar_ ScalarB;
- /// The type for C and D.
- typedef Scalar_ ScalarC;
+ /// The type for alpha and beta
+ typedef ScalarAlphaBeta_ ScalarAlphaBeta;
+  /// The type for accumulator
+ typedef ScalarAccum_ ScalarAccum;
/// Ctor.
CUTLASS_DEVICE FragmentMultiplyAdd() {}
/// Multiply : d = a*b.
template
- CUTLASS_DEVICE void multiply(Scalar_ a, FragmentB_ const& b, FragmentCd_& d) {
+ CUTLASS_DEVICE void multiply(ScalarAlphaBeta a, FragmentB_ const& b, FragmentCd_& d) {
+#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
for (int j = 0; j < FragmentCd_::kElements; ++j) {
- d[j] = a * b[j * kReduction + 0];
+ d[j] = b[j * kReduction + 0];
for (int k = 1; k < kReduction; ++k) {
- d[j] += a * b[j * kReduction + k];
+ d[j] += b[j * kReduction + k];
}
+ d[j] = a * ScalarAlphaBeta(d[j]);
}
+#endif
}
/// Multiply : d = a*b + c.
template
- CUTLASS_DEVICE void multiply_add(Scalar_ a,
+ CUTLASS_DEVICE void multiply_add(ScalarAlphaBeta a,
FragmentB_ const& b,
FragmentCd_ const& c,
FragmentCd_& d) {
+#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
for (int j = 0; j < FragmentCd_::kElements; ++j) {
- d[j] = a * b[j * kReduction + 0] + c[j];
+ d[j] = b[j * kReduction + 0];
for (int k = 1; k < kReduction; ++k) {
- d[j] += a * b[j * kReduction + k];
+ d[j] += b[j * kReduction + k];
}
+ d[j] = a * ScalarAlphaBeta(d[j]) + ScalarAlphaBeta(c[j]);
}
+#endif
}
};
@@ -80,15 +87,13 @@ struct FragmentMultiplyAdd {
#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
template <>
-struct FragmentMultiplyAdd<half> {
+struct FragmentMultiplyAdd<half, half, true> {
/// The shape of the instruction.
- typedef Shape<1, 1, 2, 1> InstructionShape;
- /// The type for A.
- typedef half ScalarA;
- /// The type for B.
- typedef half ScalarB;
- /// The type for C and D.
- typedef half ScalarC;
+ typedef Shape<1, 1, 1, 1> InstructionShape;
+ /// The type for alpha and beta
+ typedef half ScalarAlphaBeta;
+ /// The type for accumulator
+ typedef half ScalarAccum;
/// Ctor.
CUTLASS_DEVICE FragmentMultiplyAdd() {}
@@ -97,17 +102,19 @@ struct FragmentMultiplyAdd {
template <typename FragmentB_, typename FragmentCd_>
CUTLASS_DEVICE void multiply(half a, FragmentB_ const& b, FragmentCd_& d) {
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
-
- // Assemble a half2 from a.
- __half2 const a_half2 = __half2half2(a);
// The input.
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
// The output.
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
- int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
+ // Assemble a half2 from a.
+ __half2 const a_half2 = __half2half2(a);
+
+ int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements);
+
for (int j = 0; j < FragmentCd_::kElements / 2; ++j) {
d_half2[j] = __hmul2(a_half2, b_half2[j * kReduction + 0]);
+
for (int k = 1; k < kReduction; ++k) {
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
}
@@ -115,6 +122,7 @@ struct FragmentMultiplyAdd {
#endif
}
+
/// Multiply : d = a*b + c.
template <typename FragmentB_, typename FragmentCd_>
CUTLASS_DEVICE void multiply_add(half a,
@@ -122,17 +130,19 @@ struct FragmentMultiplyAdd {
FragmentCd_ const& c,
FragmentCd_& d) {
#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
- // Assemble a half2 from a.
- __half2 const a_half2 = __half2half2(a);
// The inputs.
__half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
__half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]);
// The output.
__half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
+ // Assemble a half2 from a.
+ __half2 const a_half2 = __half2half2(a);
+
int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements);
for (int j = 0; j < FragmentCd_::kElements / 2; ++j) {
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + 0], c_half2[j]);
+
for (int k = 1; k < kReduction; ++k) {
d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
}
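
Note on the FragmentMultiplyAdd rework above: the epilogue math is reassociated so that the products are first summed in the accumulator type and alpha is applied once per output element, with the alpha/beta type now independent of the accumulator type. A standalone sketch of the same reassociation (plain C++, illustrative names only, not CUTLASS symbols):

```cpp
#include <cstdio>

// Reassociated epilogue math: accumulate in ScalarAccum first, then scale by
// alpha (ScalarAlphaBeta) once per output and add the C term, mirroring the
// new multiply_add() above.
template <typename ScalarAlphaBeta, typename ScalarAccum, int kReduction>
void multiply_add(ScalarAlphaBeta a, ScalarAccum const* b, ScalarAccum const* c,
                  ScalarAccum* d, int n) {
  for (int j = 0; j < n; ++j) {
    ScalarAccum sum = b[j * kReduction];
    for (int k = 1; k < kReduction; ++k) {
      sum += b[j * kReduction + k];
    }
    d[j] = a * ScalarAlphaBeta(sum) + ScalarAlphaBeta(c[j]);
  }
}

int main() {
  float b[4] = {1, 2, 3, 4}, c[2] = {10, 20}, d[2];
  multiply_add<float, float, 2>(0.5f, b, c, d, 2);
  std::printf("%f %f\n", d[0], d[1]);  // 11.5 23.5
  return 0;
}
```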
diff --git a/cutlass/gemm/clear_accumulators.h b/cutlass/gemm/clear_accumulators.h
index 441370f4c3..3a2f337525 100644
--- a/cutlass/gemm/clear_accumulators.h
+++ b/cutlass/gemm/clear_accumulators.h
@@ -27,7 +27,7 @@
*/
#pragma once
-#include <cutlass/vector.h>
+#include "cutlass/vector.h"
namespace cutlass {
namespace gemm {
@@ -39,11 +39,12 @@ struct ClearAccumulators {
/// The shared storage.
struct SharedStorage {};
- /// Ctor.
- CUTLASS_DEVICE ClearAccumulators() {}
/// Ctor.
CUTLASS_DEVICE ClearAccumulators(SharedStorage& shared_storage) {}
+ /// Ctor.
+ CUTLASS_DEVICE ClearAccumulators() {}
+
/// Clear the fragment.
template <typename Fragment_>
CUTLASS_DEVICE void clear(Fragment_& fragment) {
diff --git a/cutlass/gemm/dgemm_traits.h b/cutlass/gemm/dgemm_traits.h
index 0bbc2210bc..5c05590207 100644
--- a/cutlass/gemm/dgemm_traits.h
+++ b/cutlass/gemm/dgemm_traits.h
@@ -27,13 +27,13 @@
*/
#pragma once
-#include <cutlass/gemm/gemm.h>
-#include <cutlass/gemm/gemm_epilogue.h>
-#include <cutlass/gemm/gemm_epilogue_traits.h>
-#include <cutlass/gemm/gemm_global_tile.h>
-#include <cutlass/gemm/gemm_shared_tile.h>
-#include <cutlass/gemm/gemm_traits.h>
-#include <cutlass/gemm/thread_multiply_add.h>
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/gemm_epilogue.h"
+#include "cutlass/gemm/gemm_epilogue_traits.h"
+#include "cutlass/gemm/gemm_global_tile.h"
+#include "cutlass/gemm/gemm_shared_tile.h"
+#include "cutlass/gemm/gemm_traits.h"
+#include "cutlass/gemm/thread_multiply_add.h"
namespace cutlass {
namespace gemm {
@@ -41,10 +41,10 @@ namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
- /// The tile size for the GEMM KxNxM.
+ /// The tile size for threadblock-level GEMM (K-by-N-by-M).
typename OutputTile_,
- /// The number of accumulators per thread.
- typename AccumulatorsPerThread_,
+ /// Tile size for thread-level GEMM (K-by-N-by-M)
+ typename ThreadGemmShape_,
/// The number of scalars per LDG for A.
int kScalarsPerLdgA_ = 1,
/// The number of scalars per LDG for B.
@@ -62,7 +62,7 @@ struct DgemmConfig
/// The tile size for the GEMM KxNxM.
OutputTile_,
/// The functor to do the math in the main loop.
- ThreadMultiplyAdd, double, double, double>,
+ ThreadMultiplyAdd, double, double, double>,
/// The number of scalars per LDG for A.
kScalarsPerLdgA_,
/// The number of scalars per STS for A.
@@ -82,7 +82,14 @@ struct DgemmConfig
/// The number of scalars per LDS for D.
1,
/// The number of stages in shared memory.
- 2> {};
+ 2,
+ /// kResidueSeparate
+ false,
+ /// kResidueInPrologue
+ false,
+ /// kLaunchBounds
+ false
+ >{};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -91,12 +98,12 @@ template <
MatrixLayout::Kind kLayoutA_,
/// The layout for B.
MatrixLayout::Kind kLayoutB_,
- /// The output tile.
+ /// The tile size for threadblock-level GEMM (K-by-N-by-M)
typename OutputTile_ = Shape<8, 64, 128>,
/// The functor to use in the epilogue.
typename EpilogueFunctor_ = LinearScaling<double>,
- /// The number of accumulators per thread.
- typename AccumulatorsPerThread_ = Shape<8, 8, 8>,
+ /// Tile size for thread-level GEMM (K-by-N-by-M)
+ typename ThreadGemmShape_ = Shape<8, 8, 8>,
/// The number of doubles loaded in one LDG for A.
int kScalarsPerLdgA_ = 1,
/// The number of doubles loaded in one LDG for B.
@@ -105,7 +112,7 @@ template <
typename Index_ = int,
/// The DGEMM config.
typename GemmConfig_ =
- DgemmConfig<OutputTile_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_>,
+ DgemmConfig<OutputTile_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_>,
/// The traits class for the epilogue.
typename GemmEpilogueTraits_ =
SimplifiedGemmEpilogueTraits >
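
For context on the renamed ThreadGemmShape_ parameter, a minimal host-side usage sketch of these traits; it assumes the Params::initialize(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, D, ldd) overload used by the bundled CUTLASS examples, so treat the exact signature as an assumption rather than something introduced by this diff.

```cpp
#include <cuda_runtime.h>
#include "cutlass/gemm/dgemm_traits.h"
#include "cutlass/gemm/gemm.h"

// Column-major DGEMM with the default threadblock tile and the default
// thread-level tile (ThreadGemmShape = Shape<8, 8, 8>).
cudaError_t run_dgemm(int m, int n, int k,
                      double alpha, double const* A, int lda,
                      double const* B, int ldb,
                      double beta, double* C, int ldc) {
  typedef cutlass::gemm::DgemmTraits<
      cutlass::MatrixLayout::kColumnMajor,
      cutlass::MatrixLayout::kColumnMajor>
      GemmTraits;

  typedef cutlass::gemm::Gemm<GemmTraits> Gemm;

  Gemm::Params params;
  // D aliases C here, i.e. C = alpha * A * B + beta * C.
  int result = params.initialize(m, n, k, alpha, A, lda, B, ldb, beta, C, ldc, C, ldc);
  if (result) {
    return cudaErrorInvalidValue;
  }
  return Gemm::launch(params);
}
```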
diff --git a/cutlass/gemm/fp16_sgemm_multiply_add.h b/cutlass/gemm/fp16_sgemm_multiply_add.h
new file mode 100644
index 0000000000..534b8c8998
--- /dev/null
+++ b/cutlass/gemm/fp16_sgemm_multiply_add.h
@@ -0,0 +1,83 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Template implementing matrix multiply-add operations on fragments.
+*/
+#pragma once
+
+#include "cutlass/fragment.h"
+#include "cutlass/gemm/thread_multiply_add.h"
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Template performing matrix multiply-add operation within a thread
+template <typename ThreadGemmShape_, typename ThreadsPerWarp_>
+struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, half, half, float> {
+ /// The shape of the instruction.
+ typedef Shape<1, 1, 1, 1> InstructionShape;
+ /// The shape of a thread-level matrix multiply accumulate.
+ typedef ThreadGemmShape_ ThreadGemmShape;
+ /// Aliased to "AccumulatorsPerThread" for compatibility. Expect to be renamed in CUTLASS v2.0
+ typedef ThreadGemmShape AccumulatorsPerThread;
+ /// The number of threads per warp.
+ typedef ThreadsPerWarp_ ThreadsPerWarp;
+ /// The number of accumulators per warp.
+ typedef typename ShapeMul<ThreadGemmShape, ThreadsPerWarp>::Shape AccumulatorsPerWarp;
+ /// The type for A. specialized to half
+ typedef half ScalarA;
+ /// The fragment for A.
+ typedef Fragment FragmentA;
+ /// The type for B. specialized to half
+ typedef half ScalarB;
+ /// The fragment for B.
+ typedef Fragment FragmentB;
+ /// The type for C and D. specialized to float
+ typedef float ScalarC;
+ /// The accumulators.
+ typedef Fragment Accumulators;
+
+ /// Ctor.
+ CUTLASS_DEVICE ThreadMultiplyAdd() {}
+
+ /// Multiply : d = a*b + c.
+ CUTLASS_DEVICE void multiply_add(FragmentA const& a,
+ FragmentB const& b,
+ Accumulators const& c,
+ Accumulators& d) {
+ for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
+ for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
+ d[j * AccumulatorsPerThread::kW + i] = static_cast<ScalarC>(a[i]) * static_cast<ScalarC>(b[j]) + c[j * AccumulatorsPerThread::kW + i];
+ }
+ }
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace gemm
+} // namespace cutlass
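
The multiply_add() above is a plain outer-product update with the half inputs promoted to float before the FMA. A standalone sketch of that access pattern (generic types, no CUTLASS dependencies; the 2x2 sizes are illustrative):

```cpp
// Thread-level outer-product multiply-accumulate: a kW-element column of A
// times a kH-element row of B updates a kH x kW accumulator tile, with the
// inputs cast up to the accumulator type first (half -> float in the trait above).
template <typename InputT, typename AccumT, int kW, int kH>
void thread_multiply_add(InputT const (&a)[kW], InputT const (&b)[kH],
                         AccumT const (&c)[kH * kW], AccumT (&d)[kH * kW]) {
  for (int j = 0; j < kH; ++j) {
    for (int i = 0; i < kW; ++i) {
      d[j * kW + i] = static_cast<AccumT>(a[i]) * static_cast<AccumT>(b[j]) + c[j * kW + i];
    }
  }
}

int main() {
  float a[2] = {1, 2}, b[2] = {3, 4};
  float c[4] = {0, 0, 0, 0}, d[4];
  thread_multiply_add<float, float, 2, 2>(a, b, c, d);
  // d = { a0*b0, a1*b0, a0*b1, a1*b1 } = { 3, 6, 4, 8 }
  return 0;
}
```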
diff --git a/cutlass/gemm/fp16_sgemm_traits.h b/cutlass/gemm/fp16_sgemm_traits.h
new file mode 100644
index 0000000000..361186455b
--- /dev/null
+++ b/cutlass/gemm/fp16_sgemm_traits.h
@@ -0,0 +1,152 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Defines structural properties of single-precision GEMM in which any of the input/output
+ operands may be fp16 or fp32. The accumulator type stays in fp32.
+*/
+#pragma once
+
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/gemm_epilogue.h"
+#include "cutlass/gemm/gemm_epilogue_traits.h"
+#include "cutlass/gemm/gemm_global_tile.h"
+#include "cutlass/gemm/gemm_shared_tile.h"
+#include "cutlass/gemm/gemm_traits.h"
+#include "cutlass/gemm/fp16_sgemm_multiply_add.h"
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+ /// The tile size for the GEMM KxNxM.
+ typename OutputTile_,
+ /// Tile size for thread-level GEMM (K-by-N-by-M)
+ typename ThreadGemmShape_,
+ /// The type for A
+ typename ScalarA_,
+ /// The type for B
+ typename ScalarB_,
+ /// The type for C
+ typename ScalarC_,
+ /// The type for D
+ typename ScalarD_,
+ /// The number of scalars per LDG for A.
+ int kScalarsPerLdgA_ = 1,
+ /// The number of scalars per LDG for B.
+ int kScalarsPerLdgB_ = 1>
+struct Fp16SgemmConfig : public GemmConfig<
+ /// The scalar type for A.
+ ScalarA_,
+ /// The scalar type for B.
+ ScalarB_,
+ /// The scalar type for C.
+ ScalarC_,
+ /// The scalar type for D.
+ ScalarD_,
+ /// The tile size for the GEMM KxNxM.
+ OutputTile_,
+ /// The functor to do the math in the main loop.
+ ThreadMultiplyAdd, ScalarA_, ScalarB_, float /*for sgemm accum is float*/>,
+ /// The number of scalars per LDG for A.
+ kScalarsPerLdgA_,
+ /// The number of scalars per STS for A.
+ kScalarsPerLdgA_,
+ /// The number of scalars per LDS for A.
+ 4,
+ /// The number of scalars per LDG for B.
+ kScalarsPerLdgB_,
+ /// The number of scalars per STS for B.
+ kScalarsPerLdgB_,
+ /// The number of scalars per LDS for B.
+ 4,
+ /// The number of scalars per LDG for C and STG for D.
+ 1,
+ /// The number of scalars per STS for D.
+ 4,
+ /// The number of scalars per LDS for D.
+ 1,
+ /// The number of stages in shared memory.
+ 2> {};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+ /// The layout for A.
+ MatrixLayout::Kind kLayoutA_,
+ /// The layout for B.
+ MatrixLayout::Kind kLayoutB_,
+ /// The output tile.
+ typename OutputTile_ = Shape<8, 128, 128>,
+ /// The type for A
+ typename ScalarA_ = half,
+ /// The type for B
+ typename ScalarB_ = half,
+ /// The type for C
+ typename ScalarC_ = half,
+ /// The type for D
+ typename ScalarD_ = half,
+ /// the Type for alpha and beta,
+ typename Scalar_ = half,
+ /// The functor to use in the epilogue.
+ typename EpilogueFunctor_ = LinearScaling >,
+ /// Tile size for thread-level GEMM (K-by-N-by-M)
+ typename ThreadGemmShape_ = Shape<8, 8, 8>,
+ /// The number of floats loaded in one LDG for A.
+ int kScalarsPerLdgA_ = 1,
+ /// The number of floats loaded in one LDG for B.
+ int kScalarsPerLdgB_ = 1,
+ /// The index.
+ typename Index_ = int,
+ /// The SGEMM config.
+ typename GemmConfig_ =
+ Fp16SgemmConfig,
+ /// The traits class for the epilogue.
+ typename GemmEpilogueTraits_ =
+ SimplifiedGemmEpilogueTraits >
+struct Fp16SgemmSgemmTraits : public SimplifiedGemmTraits<
+ // The layout for A.
+ kLayoutA_,
+ // The layout for B.
+ kLayoutB_,
+ // The config.
+ GemmConfig_,
+ // The epilogue.
+ GemmEpilogue,
+ // The index.
+ Index_> {};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace gemm
+} // namespace cutlass
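
Instantiation sketch for the new mixed-precision traits; the argument order follows the template parameter list in this file (layouts, threadblock tile, A/B/C/D element types, then the alpha/beta type), with everything after that left at its defaults. Compile as CUDA source; accumulation stays in fp32 regardless of the element types chosen.

```cpp
#include <cuda_fp16.h>
#include "cutlass/gemm/fp16_sgemm_traits.h"
#include "cutlass/gemm/gemm.h"

// half A/B, float C/D, float alpha/beta, fp32 accumulation.
typedef cutlass::gemm::Fp16SgemmSgemmTraits<
    cutlass::MatrixLayout::kColumnMajor,  // layout of A
    cutlass::MatrixLayout::kColumnMajor,  // layout of B
    cutlass::Shape<8, 128, 128>,          // threadblock tile (K, N, M)
    half,                                 // ScalarA
    half,                                 // ScalarB
    float,                                // ScalarC
    float,                                // ScalarD
    float                                 // alpha/beta type
> Fp16GemmTraits;

// The resulting traits plug into the device-side GEMM the same way as the
// other traits classes touched by this changeset.
typedef cutlass::gemm::Gemm<Fp16GemmTraits> Fp16Gemm;
```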
diff --git a/cutlass/gemm/gemm.h b/cutlass/gemm/gemm.h
index c50a3f04b4..6340ab4f33 100644
--- a/cutlass/gemm/gemm.h
+++ b/cutlass/gemm/gemm.h
@@ -31,16 +31,32 @@
#include <cuda.h>
#endif
-#include <cutlass/coord.h>
-#include <cutlass/util/platform.h>
-
+#include "cutlass/coord.h"
+#include "cutlass/util/platform.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
+/// GEMM kernel with launch bounds specified
+template <typename Gemm_>
+__global__ __launch_bounds__(Gemm_::kThreads)
+void gemm_kernel(typename Gemm_::Params params) {
+ // Declare shared memory.
+ __shared__ typename Gemm_::SharedStorage shared_storage;
+
+ // Construct the GEMM object.
+ Gemm_ gemm(params, shared_storage);
+ // Run GEMM.
+ gemm.multiply_add();
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// GEMM kernel without launch bounds specified
template <typename Gemm_>
-__global__ /*__launch_bounds__(Gemm_::kThreads)*/ void gemm_kernel(typename Gemm_::Params params) {
+__global__ /* __launch_bounds__(Gemm_::kThreads) */
+void gemm_kernel_nolb(typename Gemm_::Params params) {
// Declare shared memory.
__shared__ typename Gemm_::SharedStorage shared_storage;
@@ -52,28 +68,22 @@ __global__ /*__launch_bounds__(Gemm_::kThreads)*/ void gemm_kernel(typename Gemm
////////////////////////////////////////////////////////////////////////////////////////////////////
-template <typename Scalar_, typename Index_ = int>
-struct GemmDesc {
- /// The dimensions of the GEMM.
- Index_ m, n, k;
- /// The alpha/beta scaling values.
- Scalar_ alpha, beta;
- /// The source matrix A.
- void const* d_a;
- /// The stride for A.
- Index_ lda;
- /// The source matrix B.
- void const* d_b;
- /// The stride for B.
- Index_ ldb;
- /// The source matrix C.
- void const* d_c;
- /// The stride for C.
- Index_ ldc;
- /// The destination matrix D.
- void* d_d;
- /// The stride for D.
- Index_ ldd;
+/// Partial specialization for launching the GEMM kernel with or without launch bounds
+template <typename Gemm, bool WithLaunchBounds>
+struct Launch {
+ Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) {
+ gemm_kernel<Gemm><<< grid, block, 0, stream >>>(params);
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Partial specialization for launching the GEMM kernel with or without launch bounds
+template <typename Gemm>
+struct Launch<Gemm, false> {
+ Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) {
+ gemm_kernel_nolb<Gemm><<< grid, block, 0, stream >>>(params);
+ }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -100,86 +110,52 @@ struct Gemm {
/// The index.
typedef typename Traits::Index Index;
+ /// Define the mainloop iteration size
+ typedef typename Traits::MultiplyAdd MultiplyAdd;
+
/// The number of threads.
static int const kThreads = Traits::GemmConfig::kThreads;
- /// The params.
- struct Params : public Traits::Params {
- CUTLASS_HOST_DEVICE int initialize(Index m,
- Index n,
- Index k,
- ScalarEpilogue alpha,
- ScalarA const* d_a,
- Index lda,
- ScalarB const* d_b,
- Index ldb,
- ScalarEpilogue beta,
- ScalarC const* d_c,
- Index ldc,
- ScalarD* d_d,
- Index ldd) {
- GemmDesc desc;
- desc.m = m;
- desc.n = n;
- desc.k = k;
- desc.alpha = alpha;
- desc.beta = beta;
- desc.d_a = reinterpret_cast(d_a);
- desc.lda = lda;
- desc.d_b = reinterpret_cast(d_b);
- desc.ldb = ldb;
- desc.d_c = reinterpret_cast(d_c);
- desc.ldc = ldc;
- desc.d_d = reinterpret_cast(d_d);
- desc.ldd = ldd;
- return Traits::Params::initialize(desc);
- }
- };
+ // Number of warp-level multiply-accumulate steps executed by each warp.
+ static Index const kWarpGemmSteps =
+ Traits::GemmConfig::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD;
+
+ // Make sure we have at least 2 unrolling steps or our pipelining is not going to work.
+ static_assert(kWarpGemmSteps >= 2, "The pipelining assumes at least two steps");
+ /// Use the params object defined in traits
+ typedef typename Traits::Params Params;
+
+//
+// Static function members
+//
+
+/// Support for NVRTC
#if !defined(__CUDACC_RTC__)
/// Launch the kernel.
static __host__ cudaError_t launch(Params const& params,
cudaStream_t stream = cudaStreamDefault) {
- // Setup the grid.
- dim3 grid;
- grid.x = (params.m + Traits::OutputTile::kW - 1) / Traits::OutputTile::kW;
- grid.y = (params.n + Traits::OutputTile::kH - 1) / Traits::OutputTile::kH;
-
- // The number of threads.
- dim3 block;
- block.x = kThreads;
// Launch the kernel.
- void const* params_ = reinterpret_cast(¶ms);
-
- return cudaLaunchKernel(reinterpret_cast(&gemm_kernel),
- grid,
- block,
- const_cast(¶ms_),
- 0,
- stream);
+ Launch<Gemm, Traits::GemmConfig::kLaunchBounds>(
+ params, params.grid, params.block, stream);
+
+ return cudaGetLastError();
}
/// Launch the kernel.
static __host__ cudaError_t launch(CUfunction kernel,
Params const& params,
CUstream stream = CU_STREAM_LEGACY) {
- // Setup the grid.
- dim3 grid;
- grid.x = (params.m + Traits::OutputTile::kW - 1) / Traits::OutputTile::kW;
- grid.y = (params.n + Traits::OutputTile::kH - 1) / Traits::OutputTile::kH;
-
- // The number of threads.
- dim3 block;
- block.x = kThreads;
// Launch the kernel.
void* params_[] = {const_cast<void*>(reinterpret_cast<void const*>(&params))};
- // return cudaLaunchKernel(reinterpret_cast(&gemm_kernel), grid, block,
- // const_cast(¶ms_), 0, stream);
CUresult result = cuLaunchKernel(
- kernel, grid.x, grid.y, grid.z, block.x, block.y, block.z, 0, stream, params_, 0);
+ kernel,
+ params.grid.x, params.grid.y, params.grid.z,
+ params.block.x, params.block.y, params.block.z,
+ 0, stream, params_, 0);
if (result != CUDA_SUCCESS) {
return cudaErrorLaunchFailure;
@@ -189,39 +165,41 @@ struct Gemm {
#endif
+ //
+ // Methods
+ //
+
/// Ctor.
CUTLASS_DEVICE Gemm(Params const& params_, SharedStorage& shared_storage_)
: params(params_), shared_storage(shared_storage_) {}
- /// Consume a single iteration of the loop.
- template
- CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_stream,
- typename Traits::SharedLoadStream& shared_load_stream,
- typename Traits::MultiplyAdd::Accumulators& accumulators,
+ /// Computes a warp-level GEMM on data held in shared memory
+ template <bool Residue, bool LastIteration>
+ CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_to_shared_stream,
+ typename Traits::SharedStream& shared_load_stream,
+ typename MultiplyAdd::Accumulators& accumulators,
Index outer_k) {
- // If that's the last "load iteration" update the predicates.
- if (!kIsLastIteration) {
- global_stream.move_to_residue(outer_k);
+ // If residue portion and not calculating residue in prolog, update residue predicates now.
+ if (Residue && outer_k <= Traits::OutputTile::kD) {
+ global_to_shared_stream.residue(outer_k);
}
- // Load data for the next iteration of the main loop.
- if (!kIsLastIteration) {
- global_stream.copy();
+ // Load data for the next iteration of the main loop (unless it's the last iteration).
+ if (!LastIteration) {
+ global_to_shared_stream.copy();
}
- // The unrolling steps for the main loop.
- int const kUnrollingSteps =
- Traits::MultiplyAdd::AccumulatorsPerWarp::kD / Traits::MultiplyAdd::InstructionShape::kD;
-
CUTLASS_PRAGMA_UNROLL
- for (int step = 0; step < kUnrollingSteps - 1; ++step) {
+ for (int step = 0; step < kWarpGemmSteps - 1; ++step) {
// Trigger the copy from shared memory for the next A/B values.
shared_load_stream.copy(step + 1);
+
// Make sure the values are available for the current iteration to do the multiply-add.
shared_load_stream.commit(step);
+ MultiplyAdd multiply_add;
+
// Do the math on the fragments of the current iteration.
- typename Traits::MultiplyAdd multiply_add;
multiply_add.multiply_add(shared_load_stream.fragment_a(step),
shared_load_stream.fragment_b(step),
accumulators,
@@ -232,28 +210,25 @@ struct Gemm {
Traits::shared_load_fence(true);
// Commit the data in shared memory for A/B.
- if (!kIsLastIteration) {
- global_stream.commit();
+ if (!LastIteration) {
+ global_to_shared_stream.commit();
}
-
// Make sure the data is in shared memory.
Traits::shared_store_fence(true);
- // Trigger the loads for the next iteration (if needed).
- if (!kIsLastIteration) {
+ if (!LastIteration) {
// Move to the next stage for the load (if it makes sense).
shared_load_stream.inc_stage();
// Trigger the copy from shared memory for the next loop iteration.
shared_load_stream.copy(0);
}
-
// Make sure the values are available for the current iteration to do the multiply-add.
- shared_load_stream.commit(kUnrollingSteps - 1);
+ shared_load_stream.commit(kWarpGemmSteps - 1);
// Do the math on the fragments of the current iteration.
- typename Traits::MultiplyAdd multiply_add;
- multiply_add.multiply_add(shared_load_stream.fragment_a(kUnrollingSteps - 1),
- shared_load_stream.fragment_b(kUnrollingSteps - 1),
+ MultiplyAdd multiply_add;
+ multiply_add.multiply_add(shared_load_stream.fragment_a(kWarpGemmSteps - 1),
+ shared_load_stream.fragment_b(kWarpGemmSteps - 1),
accumulators,
accumulators);
}
@@ -262,76 +237,112 @@ struct Gemm {
CUTLASS_DEVICE void multiply_add() {
// Swizzle the IDs of the block (to enable better cache behavior).
typename Traits::BlockSwizzle block_swizzle;
- dim3 block = block_swizzle.swizzle();
-
- // Scale the id.
- block.x *= Traits::OutputTile::kW;
- block.y *= Traits::OutputTile::kH;
+ Coord<3> threadblock_offset =
+ block_swizzle.get_threadblock_offset(make_Coord_from_shape());
// We may want to use shared memory to clear the registers.
typedef typename Traits::ClearAccumulators ClearAccumulators;
// The streams to read A/B from global memory to shared memory.
- typename Traits::GlobalLoadStream global_stream(params, shared_storage, block);
+ typename Traits::GlobalLoadStream global_to_shared_stream(
+ params.global_to_shared_stream,
+ shared_storage.main_loop.global_to_shared_stream,
+ shared_storage.main_loop.threadblock_tile.reference(),
+ params.problem_size.knm(),
+ threadblock_offset);
- // Create the accumulator clear.
- ClearAccumulators clear(shared_storage.main_loop.clear);
+ // update A and B pointer offset based on batch_id and batch_stride_offset
+ //global_to_shared_stream.add_pointer_offset(block_swizzle.get_batch_id(), params.batch_stride_A, params.batch_stride_B);
+ global_to_shared_stream += make_Coord(block_swizzle.get_batch_id(), 0, 0);
- // By how much we unroll the main loop.
- Index const kUnroll = static_cast(Traits::OutputTile::kD);
+ // Create the accumulator clear.
+ ClearAccumulators clear;
- // If we do not have enough steps in the main loop, trigger the residue code.
- global_stream.move_to_residue(params.k);
+ // Deal with residue in prolog.
+ global_to_shared_stream.move_to_residue(params.problem_size[0], Traits::OutputTile::kD);
// Fetch the fragments for A and B from global memory.
- global_stream.copy();
+ global_to_shared_stream.copy();
// Copy the elements to shared memory (after transformation if needed).
- global_stream.commit();
+ global_to_shared_stream.commit();
// Make sure the data is in shared memory.
Traits::shared_store_fence(false);
- // Rollback to the beginning of the GEMM-K dimension. It may have no impact.
- global_stream.rollback();
-
- // The unrolling steps for the main loop.
- int const kUnrollingSteps =
- Traits::MultiplyAdd::AccumulatorsPerWarp::kD / Traits::MultiplyAdd::InstructionShape::kD;
-
- // Make sure we have at least 2 unrolling steps or our pipeling is not going to work.
- static_assert(kUnrollingSteps >= 2, "The pipelining assumes at least two steps");
+ // Rollback to the beginning of the first tile (if residue exists).
+ global_to_shared_stream.rollback(params.problem_size[0] % Traits::OutputTile::kD);
// The stream of data from shared memory to fragments.
- typename Traits::SharedLoadStream shared_load_stream(params, shared_storage);
+ typename Traits::SharedStream shared_load_stream(
+ params.shared_stream,
+ shared_storage.main_loop.threadblock_tile.reference());
// Trigger the copy from shared memory for the 1st stream.
shared_load_stream.copy(0);
// Allocate the accumulators.
- typename Traits::MultiplyAdd::Accumulators accumulators;
+ typename MultiplyAdd::Accumulators accumulators;
+
// Clear the accumulators.
clear.clear(accumulators);
- // The loop index.
- Index outer_k = params.k - kUnroll;
+ // Initial index
+ Index outer_k = params.problem_size[0] - Traits::OutputTile::kD;
- // Enter the main loop and iterate.
- for (; outer_k > 0; outer_k -= kUnroll) {
- consume_tile(global_stream, shared_load_stream, accumulators, outer_k);
- }
+ // Check if we are computing residue in prolog or not.
+ if (Traits::GemmConfig::kResidueInProlog) {
+
+ // Execute all mainloop iterations but the last one.
+
+ CUTLASS_GEMM_LOOP
+ for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) {
+ consume_tile(
+ global_to_shared_stream, shared_load_stream, accumulators, outer_k);
+
+ }
+
+ // Don't load data for the last "residue" portion since we've already computed the residue.
+ CUTLASS_GEMM_LOOP
+ for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
+ consume_tile(
+ global_to_shared_stream, shared_load_stream, accumulators, outer_k);
- // Residual loop.
- for (; outer_k > -kUnroll; outer_k -= kUnroll) {
- consume_tile(global_stream, shared_load_stream, accumulators, outer_k);
+ }
+ } else {
+ // When kResidueSeparate = true, execute all mainloop iterations but the last two without any
+ // consideration for K-residue or predicate updates. This improves the steady state of some
+ // kernels.
+ if (Traits::GemmConfig::kResidueSeparate) {
+
+ CUTLASS_GEMM_LOOP
+ for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
+ consume_tile(
+ global_to_shared_stream, shared_load_stream, accumulators, outer_k);
+
+ }
+ }
+
+ // Execute remaining tiles with K-residue predicate updates enabled.
+
+ CUTLASS_GEMM_LOOP
+ for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) {
+ consume_tile(
+ global_to_shared_stream, shared_load_stream, accumulators, outer_k);
+
+ }
}
// Epilogue.
typedef typename Traits::Epilogue Epilogue;
- Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.m, params.n);
- epilogue.epilogue(cutlass::make_Coord(0, block.y, block.x), accumulators);
+ Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.problem_size.knm());
+ epilogue.epilogue(accumulators, threadblock_offset, block_swizzle.get_batch_id());
}
+ //
+ // Data members
+ //
+
/// The params.
Params const& params;
/// The shared storage.
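
The Launch<> pair above selects at compile time between a kernel compiled with __launch_bounds__ and one without. A standalone CUDA sketch of the same dispatch pattern (simplified to a single bool parameter; the kernel and symbol names are illustrative, not CUTLASS code):

```cpp
#include <cuda_runtime.h>

static int const kThreads = 128;

__global__ __launch_bounds__(kThreads) void kernel_lb(float* out) { out[threadIdx.x] = threadIdx.x; }
__global__ void kernel_nolb(float* out) { out[threadIdx.x] = threadIdx.x; }

// Primary template: launch the kernel that was compiled with launch bounds.
template <bool WithLaunchBounds>
struct Launch {
  Launch(float* out, dim3 grid, dim3 block, cudaStream_t stream = 0) {
    kernel_lb<<<grid, block, 0, stream>>>(out);
  }
};

// Specialization: launch the otherwise identical kernel without launch bounds.
template <>
struct Launch<false> {
  Launch(float* out, dim3 grid, dim3 block, cudaStream_t stream = 0) {
    kernel_nolb<<<grid, block, 0, stream>>>(out);
  }
};

int main() {
  float* out = 0;
  cudaMalloc(&out, kThreads * sizeof(float));
  Launch<true>(out, dim3(1), dim3(kThreads));   // bounded kernel
  Launch<false>(out, dim3(1), dim3(kThreads));  // unbounded kernel
  cudaDeviceSynchronize();
  cudaFree(out);
  return 0;
}
```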
diff --git a/cutlass/gemm/gemm_config.h b/cutlass/gemm/gemm_config.h
new file mode 100644
index 0000000000..76df0add62
--- /dev/null
+++ b/cutlass/gemm/gemm_config.h
@@ -0,0 +1,145 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Defines properties of GEMM computation that impose some constraints on caller.
+*/
+#pragma once
+
+#include "cutlass/shape.h"
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+ /// The scalar type for A.
+ typename ScalarA_,
+ /// The scalar type for B.
+ typename ScalarB_,
+ /// The scalar type for C.
+ typename ScalarC_,
+ /// The scalar type for D.
+ typename ScalarD_,
+ /// The threadblock tile size for the GEMM KxNxM.
+ typename OutputTile_,
+ /// The functor to do the math.
+ typename MultiplyAdd_,
+ /// The number of scalars per LDG for A.
+ int kScalarsPerLdgA_,
+ /// The number of scalars per STS for A.
+ int kScalarsPerStsA_,
+ /// The number of scalars per LDS for A.
+ int kScalarsPerLdsA_,
+ /// The number of scalars per LDG for B.
+ int kScalarsPerLdgB_,
+ /// The number of scalars per STS for B.
+ int kScalarsPerStsB_,
+ /// The number of scalars per LDS for B.
+ int kScalarsPerLdsB_,
+ /// The number of scalars per LDG for C and STG for D.
+ int kScalarsPerLdgCAndStgD_,
+ /// The number of scalars per STS for D.
+ int kScalarsPerStsD_,
+ /// The number of scalars per LDS for D.
+ int kScalarsPerLdsD_,
+ /// The number of stages in shared memory to do single/double/triple-buffering.
+ int kStages_,
+ /// If true, the mainloop is instantiated twice and residue is handled by a separate instantiation. If false, a single mainloop handles the residue.
+ bool kResidueSeparate_ = false,
+ /// Is residue performed in prologue?
+ bool kResidueInProlog_ = false,
+ /// If true, kernel is launched with CUDA launch bounds specified
+ bool kLaunchBounds_ = true>
+struct GemmConfig {
+ //
+ /// The scalar for A.
+ typedef ScalarA_ ScalarA;
+ /// The scalar for B.
+ typedef ScalarB_ ScalarB;
+ /// The scalar for C.
+ typedef ScalarC_ ScalarC;
+ /// The scalar for D.
+ typedef ScalarD_ ScalarD;
+
+ /// The tile.
+ typedef OutputTile_ OutputTile;
+ /// The functor to do D = A*B + C.
+ typedef MultiplyAdd_ MultiplyAdd;
+ /// The shape of the instruction.
+ typedef typename MultiplyAdd::InstructionShape InstructionShape;
+ /// The shape of warp-level GEMM
+ typedef typename MultiplyAdd::AccumulatorsPerWarp AccumulatorsPerWarp;
+ /// The accumulators.
+ typedef typename MultiplyAdd::Accumulators Accumulators;
+
+ /// The number of warps.
+ typedef typename ShapeDiv<OutputTile, AccumulatorsPerWarp>::Shape Warps;
+ /// The default warp size (32 threads per warp).
+ static int const kWarpSize = cutlass::kWarpSize;
+ /// The number of threads.
+ static int const kThreads = ShapeCount<Warps>::kCount * kWarpSize;
+
+ /// The number of scalars per LDG/STS/LDS for A.
+ static int const kScalarsPerLdgA = kScalarsPerLdgA_;
+ static int const kScalarsPerStsA = kScalarsPerStsA_;
+ static int const kScalarsPerLdsA = kScalarsPerLdsA_;
+
+ /// The number of scalars per LDG/STS/LDS for B.
+ static int const kScalarsPerLdgB = kScalarsPerLdgB_;
+ static int const kScalarsPerStsB = kScalarsPerStsB_;
+ static int const kScalarsPerLdsB = kScalarsPerLdsB_;
+
+ /// The number of scalars per LDG for C.
+ static int const kScalarsPerLdgC = kScalarsPerLdgCAndStgD_;
+
+ /// The number of scalars per STS/LDS/STG for D.
+ static int const kScalarsPerStgD = kScalarsPerLdgCAndStgD_;
+ static int const kScalarsPerStsD = kScalarsPerStsD_;
+ static int const kScalarsPerLdsD = kScalarsPerLdsD_;
+
+ /// The number of accumulators that are going to be fed from one LDS A/B.
+ static int const kAccumulatorsPerLdsA = kScalarsPerLdsA / InstructionShape::kD;
+ static int const kAccumulatorsPerLdsB = kScalarsPerLdsB / InstructionShape::kD;
+
+ /// The number of stages in shared memory to implement double, triple, more-buffering.
+ static int const kStages = kStages_;
+
+ /// If true, mainloop is instantiated twice. The first instantiation contains no predicate
+ // updates and is more efficient for some kernels. If false, only a single mainloop is
+ // instantiated.
+ static bool const kResidueSeparate = kResidueSeparate_;
+
+ /// If true, residue is computed in the prologue.
+ static bool const kResidueInProlog = kResidueInProlog_;
+
+ /// If true, kernel is launched with launch bounds specified
+ static bool const kLaunchBounds = kLaunchBounds_;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace gemm
+} // namespace cutlass
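
The Warps/kThreads members above are just shape arithmetic. A standalone sketch of that arithmetic with illustrative values (a typical 128x128 SGEMM threadblock tile, an 8x8x8 per-thread tile, and an 8x4 warp arrangement; none of these numbers are mandated by this file):

```cpp
#include <cstdio>

struct Shape3 { int d, h, w; };  // (K, N, M), mirroring Shape<D, H, W>

constexpr Shape3 shape_mul(Shape3 a, Shape3 b) { return {a.d * b.d, a.h * b.h, a.w * b.w}; }
constexpr Shape3 shape_div(Shape3 a, Shape3 b) { return {a.d / b.d, a.h / b.h, a.w / b.w}; }
constexpr int shape_count(Shape3 a) { return a.d * a.h * a.w; }

int main() {
  constexpr Shape3 output_tile       = {8, 128, 128};  // threadblock tile
  constexpr Shape3 thread_gemm_shape = {8, 8, 8};      // accumulators per thread
  constexpr Shape3 threads_per_warp  = {1, 4, 8};      // warp arrangement of threads

  constexpr Shape3 accum_per_warp = shape_mul(thread_gemm_shape, threads_per_warp);  // {8, 32, 64}
  constexpr Shape3 warps          = shape_div(output_tile, accum_per_warp);          // {1, 4, 2}
  constexpr int    kWarpSize      = 32;
  constexpr int    kThreads       = shape_count(warps) * kWarpSize;                  // 256

  std::printf("warps = %d x %d x %d, threads = %d\n", warps.d, warps.h, warps.w, kThreads);
  return 0;
}
```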
diff --git a/cutlass/gemm/gemm_coord.h b/cutlass/gemm/gemm_coord.h
new file mode 100644
index 0000000000..8e36bb0430
--- /dev/null
+++ b/cutlass/gemm/gemm_coord.h
@@ -0,0 +1,203 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief GemmCoord is a structure derived from Coord<4> that specifies a location within the
+ coordinate system of a GEMM problem.
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/coord.h"
+#include "cutlass/util/platform.h"
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// GemmCoord is a structure derived from Coord<4> that specifies a location within the
+/// coordinate space of a GEMM problem.
+struct GemmCoord : public Coord<4, int> {
+
+ /// Integer-valued index
+ typedef int Index;
+
+ /// Base type is a Coord of rank=4
+ typedef Coord<4, Index> Base;
+
+ /// GEMM K dimension - inner dimension of the GEMM problem
+ static int const kK = 0;
+
+ /// GEMM N dimension - columns of the output C matrix
+ static int const kN = 1;
+
+ /// GEMM M dimension - rows of the output C matrix
+ static int const kM = 2;
+
+ /// Batch dimension - for generalizing to larger problems
+ static int const kBatch = 3;
+
+ //
+ // Methods
+ //
+
+ /// Default ctor
+ CUTLASS_HOST_DEVICE
+ GemmCoord() { }
+
+ /// Constructs from Coord<3> and a batch
+ CUTLASS_HOST_DEVICE
+ GemmCoord(Coord<3, Index> const &coord, Index _batch = 0): Base(make_Coord(coord[0], coord[1], coord[2], _batch)) { }
+
+ /// Constructs from Coord<4>
+ CUTLASS_HOST_DEVICE
+ GemmCoord(Coord<4, Index> const &coord): Base(coord) { }
+
+ /// Constructs from an array of coordinate elements
+ CUTLASS_HOST_DEVICE
+ GemmCoord(Index coord[4]): Base(coord) { }
+
+ /// Helper to construct from a K, N, M, batch variables
+ CUTLASS_HOST_DEVICE
+ GemmCoord(Index k, Index n, Index m, Index batch = 0): Base(make_Coord(k, n, m, batch)) { }
+
+ /// Returns the GEMM M coordinate
+ CUTLASS_HOST_DEVICE
+ Index const & m() const { return this->at(kM); }
+
+ /// Returns reference to the GEMM M coordinate
+ CUTLASS_HOST_DEVICE
+ Index & m() { return this->at(kM); }
+
+ /// Returns the GEMM N coordinate
+ CUTLASS_HOST_DEVICE
+ Index const & n() const { return this->at(kN); }
+
+ /// Returns reference to the GEMM N coordinate
+ CUTLASS_HOST_DEVICE
+ Index & n() { return this->at(kN); }
+
+ /// Returns the GEMM K coordinate
+ CUTLASS_HOST_DEVICE
+ Index const & k() const { return this->at(kK); }
+
+ /// Returns reference to the GEMM K coordinate
+ CUTLASS_HOST_DEVICE
+ Index & k() { return this->at(kK); }
+
+ /// Returns the GEMM batch coordinate
+ CUTLASS_HOST_DEVICE
+ Index const & batch() const { return this->at(kBatch); }
+
+ /// Returns reference to the GEMM batch coordinate
+ CUTLASS_HOST_DEVICE
+ Index & batch() { return this->at(kBatch); }
+
+ /// Obtains a Coord<3> from GemmCoord
+ CUTLASS_HOST_DEVICE
+ Coord<3> knm() const {
+ return make_Coord(k(), n(), m());
+ }
+
+ /// Obtains a Coord<2> from GemmCoord
+ CUTLASS_HOST_DEVICE
+ Coord<2> nm() const {
+ return make_Coord(n(), m());
+ }
+
+ /// Obtains a Coord<2> from GemmCoord
+ CUTLASS_HOST_DEVICE
+ Coord<2> km() const {
+ return make_Coord(k(), m());
+ }
+
+ /// Obtains a Coord<2> from GemmCoord
+ CUTLASS_HOST_DEVICE
+ Coord<2> kn() const {
+ return make_Coord(k(), n());
+ }
+
+ //
+ // Coord operators
+ //
+
+ /// Element-wise addition
+ CUTLASS_HOST_DEVICE
+ GemmCoord operator+(Base const& b) const {
+ return GemmCoord(Base::operator+(b));
+ }
+
+ /// Element-wise subtraction
+ CUTLASS_HOST_DEVICE
+ GemmCoord operator-(Base const& b) const {
+ return GemmCoord(Base::operator-(b));
+ }
+
+ /// Element-wise multiplication
+ CUTLASS_HOST_DEVICE
+ GemmCoord operator*(Base const& b) const {
+ return GemmCoord(Base::operator*(b));
+ }
+
+ /// Element-wise division
+ CUTLASS_HOST_DEVICE
+ GemmCoord operator/(Base const& b) const {
+ return GemmCoord(Base::operator/(b));
+ }
+
+ /// In-place addition
+ CUTLASS_HOST_DEVICE
+ GemmCoord& operator+=(Base const& b) {
+ Base::operator+=(b);
+ return *this;
+ }
+
+ /// In-place subtraction
+ CUTLASS_HOST_DEVICE
+ GemmCoord& operator-=(Base const& b) {
+ Base::operator-=(b);
+ return *this;
+ }
+
+ /// In-place multiplication
+ CUTLASS_HOST_DEVICE
+ GemmCoord& operator*=(Base const& b) {
+ Base::operator*=(b);
+ return *this;
+ }
+
+ /// In-place division
+ CUTLASS_HOST_DEVICE
+ GemmCoord& operator/=(Base const& b) {
+ Base::operator/=(b);
+ return *this;
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace gemm
+} // namespace cutlass
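
A small host-side usage sketch for the new GemmCoord, showing the (k, n, m, batch) storage order behind the named accessors; it assumes the header builds standalone under nvcc.

```cpp
#include <cstdio>
#include "cutlass/gemm/gemm_coord.h"

int main() {
  // K = 64, N = 512, M = 1024, batch = 3, matching the kK/kN/kM/kBatch indices.
  cutlass::gemm::GemmCoord problem(64, 512, 1024, 3);

  std::printf("M=%d N=%d K=%d batch=%d\n",
              problem.m(), problem.n(), problem.k(), problem.batch());

  // knm() drops the batch dimension and keeps the (k, n, m) ordering.
  cutlass::Coord<3> knm = problem.knm();
  std::printf("knm = (%d, %d, %d)\n", knm.at(0), knm.at(1), knm.at(2));
  return 0;
}
```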
diff --git a/cutlass/gemm/gemm_desc.h b/cutlass/gemm/gemm_desc.h
new file mode 100644
index 0000000000..80f4b36557
--- /dev/null
+++ b/cutlass/gemm/gemm_desc.h
@@ -0,0 +1,205 @@
+/***************************************************************************************************
+ * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without modification, are permitted
+ * provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright notice, this list of
+ * conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright notice, this list of
+ * conditions and the following disclaimer in the documentation and/or other materials
+ * provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ * to endorse or promote products derived from this software without specific prior written
+ * permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Implements a software-pipelined efficient GEMM.
+*/
+#pragma once
+
+#include "cutlass/tensor_ref.h"
+#include "cutlass/gemm/gemm_coord.h"
+
+namespace cutlass {
+namespace gemm {
+
+/// GEMM problem description
+template <
+ /// Type of the A operand
+ typename AType_,
+ /// Type of the B operand
+ typename BType_,
+ /// Source accumulator matrix type
+ typename CType_,
+ /// Destination accumulator type
+ typename DType_,
+ /// Scalar type for alpha and beta
+ typename SType_,
+ /// Index type for dimensions and strides
+ typename Index_ = int
+> struct GemmDesc {
+ //
+ // Type definitions
+ //
+
+ /// Index type for dimensions and strides
+ typedef Index_ Index;
+
+ /// Type of the A operand
+ typedef AType_ AType;
+
+ /// Tensor reference to A operand
+ typedef TensorRef TensorRefA;
+
+ /// Type of the B operand
+ typedef BType_ BType;
+
+ /// Tensor reference to B operand
+ typedef TensorRef TensorRefB;
+
+ /// Source accumulator matrix type
+ typedef CType_ CType;
+
+ /// Tensor reference to C operand
+ typedef TensorRef TensorRefC;
+
+ /// Destination accumulator type
+ typedef DType_ DType;
+
+ /// Tensor reference to D operand
+ typedef TensorRef TensorRefD;
+
+ /// Scalar type for alpha and beta
+ typedef SType_ SType;
+
+ //
+ // Data members
+ //
+
+ /// The dimensions of the GEMM.
+ GemmCoord problem_size;
+
+ /// The alpha scaling values.
+ SType alpha;
+
+ /// The source matrix A.
+ TensorRefA A;
+
+ /// batch stride for A operand
+ long long batch_stride_A;
+
+ /// The source matrix B.
+ TensorRefB B;
+
+ /// batch stride for B operand
+ long long batch_stride_B;
+
+ /// The beta scaling values.
+ SType beta;
+
+ /// The source matrix C.
+ TensorRefC C;
+
+ /// batch stride for C operand
+ long long batch_stride_C;
+
+ /// The destination matrix D.
+ TensorRefD D;
+
+ /// batch stride for D operand
+ long long batch_stride_D;
+
+ //
+ // Methods
+ //
+
+ /// Default ctor
+ CUTLASS_HOST_DEVICE
+ GemmDesc(): problem_size(0, 0, 0, 1), alpha(1), beta(0) {}
+
+ /// Constructor for basic GEMM with batch count = 1
+ CUTLASS_HOST_DEVICE
+ GemmDesc(Coord<3> _problem_size,
+ SType _alpha,
+ TensorRefA const &_A,
+ TensorRefB const &_B,
+ SType _beta,
+ TensorRefC const &_C,
+ TensorRefD const &_D
+ ):
+ problem_size(_problem_size[0], _problem_size[1], _problem_size[2], 1),
+ alpha(_alpha),
+ A(_A),
+ batch_stride_A(0),
+ B(_B),
+ batch_stride_B(0),
+ beta(_beta),
+ C(_C),
+ batch_stride_C(0),
+ D(_D),
+ batch_stride_D(0) {}
+
+ /// Constructor for basic GEMM with batch count = 1
+ CUTLASS_HOST_DEVICE
+ GemmDesc(GemmCoord _problem_size,
+ SType _alpha,
+ TensorRefA const &_A,
+ TensorRefB const &_B,
+ SType _beta,
+ TensorRefC const &_C,
+ TensorRefD const &_D
+ ):
+ problem_size(_problem_size.k(), _problem_size.n(), _problem_size.m(), 1),
+ alpha(_alpha),
+ A(_A),
+ batch_stride_A(0),
+ B(_B),
+ batch_stride_B(0),
+ beta(_beta),
+ C(_C),
+ batch_stride_C(0),
+ D(_D),
+ batch_stride_D(0) {
+
+ assert(_problem_size.batch() == 1);
+ }
+
+ /// Constructor for strided batched GEMM
+ CUTLASS_HOST_DEVICE
+ GemmDesc(GemmCoord _problem_size,
+ SType _alpha,
+ TensorRefA const &_A,
+ long long _batch_stride_A,
+ TensorRefB const &_B,
+ long long _batch_stride_B,
+ SType _beta,
+ TensorRefC const &_C,
+ long long _batch_stride_C,
+ TensorRefD const &_D,
+ long long _batch_stride_D
+ ):
+ problem_size(_problem_size),
+ alpha(_alpha),
+ A(_A),
+ batch_stride_A(_batch_stride_A),
+ B(_B),
+ batch_stride_B(_batch_stride_B),
+ beta(_beta),
+ C(_C),
+ batch_stride_C(_batch_stride_C),
+ D(_D),
+ batch_stride_D(_batch_stride_D) {}
+};
+
+} // namespace gemm
+} // namespace cutlass
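
The batch_stride_* members above describe how consecutive matrices of a strided batched problem are laid out. A standalone sketch of the addressing they imply (plain C++, column-major, illustrative helper name):

```cpp
#include <cstdio>
#include <cstddef>
#include <vector>

// Element (row, col) of matrix `batch_id`: start from the operand's base
// pointer, step over whole matrices with batch_stride, then use the leading
// dimension within the selected matrix.
template <typename T>
T const* batched_element(T const* base, long long batch_stride, int batch_id,
                         int row, int col, int leading_dim) {
  return base + batch_id * batch_stride +
         static_cast<std::ptrdiff_t>(col) * leading_dim + row;
}

int main() {
  int const m = 2, n = 3, batch_count = 2;
  long long const batch_stride = m * n;  // matrices packed back to back
  std::vector<float> A(batch_count * batch_stride);
  for (std::size_t i = 0; i < A.size(); ++i) A[i] = float(i);

  // (row 1, col 2) of the second matrix: index 6 + 2*2 + 1 = 11.
  std::printf("%f\n", *batched_element(A.data(), batch_stride, /*batch_id=*/1, 1, 2, m));
  return 0;
}
```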
diff --git a/cutlass/gemm/gemm_epilogue.h b/cutlass/gemm/gemm_epilogue.h
index bc25307775..d9469bb550 100644
--- a/cutlass/gemm/gemm_epilogue.h
+++ b/cutlass/gemm/gemm_epilogue.h
@@ -29,26 +29,15 @@
*/
#pragma once
-#include <cutlass/convert.h>
-#include <cutlass/coord.h>
-#include <cutlass/fragment.h>
+#include "cutlass/convert.h"
+#include "cutlass/coord.h"
+#include "cutlass/fragment.h"
namespace cutlass {
namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
-template
-CUTLASS_DEVICE bool is_zero(T x) {
- return x == T(0);
-}
-
-#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
-CUTLASS_DEVICE bool is_zero(half x) { return reinterpret_cast(x) == int16_t(0); }
-#endif
-
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
template <typename GemmEpilogueTraits_>
struct GemmEpilogue {
/// The traits class.
@@ -85,9 +74,7 @@ struct GemmEpilogue {
/// The shared store transformer for D.
typedef typename Traits::SharedStoreTransformerD SharedStoreTransformerD;
/// The iterator to load D in shared memory.
- typedef typename Traits::SharedLoadIteratorD SharedLoadIteratorD;
- /// The shared load transformer for D.
- typedef Copy SharedLoadTransformerD;
+ typedef typename Traits::SharedLoadStreamD SharedLoadStreamD;
/// The index.
typedef typename Traits::Index Index;
@@ -100,33 +87,28 @@ struct GemmEpilogue {
/// Ctor.
CUTLASS_DEVICE GemmEpilogue(Params const& params_,
SharedStorage& shared_storage_,
- Index m_,
- Index n_)
- : params(params_), shared_storage(shared_storage_), m(m_), n(n_) {}
+ Coord<3> const& _problem_size)
+ : params(params_), shared_storage(shared_storage_), problem_size(_problem_size), functor(params_.functor) {}
/// Execute the epilogue.
- CUTLASS_DEVICE void epilogue(Coord<3> const& block, Accumulators& accumulators) {
- if (is_zero(params.functor.beta)) {
- epilogue_with_or_without_beta(block, accumulators);
+ CUTLASS_DEVICE void epilogue(Accumulators& accumulators,
+ Coord<3> const& block = make_Coord(0, 0, 0),
+ int batch_id = 0) {
+ if (functor.source_required()) {
+ epilogue_with_or_without_beta(accumulators, block, batch_id);
} else {
- epilogue_with_or_without_beta(block, accumulators);
+ epilogue_with_or_without_beta(accumulators, block, batch_id);
}
}
- template <bool kBetaIsZero_>
- CUTLASS_DEVICE void epilogue_with_or_without_beta(Coord<3> const& block,
- Accumulators& accumulators) {
-
- // The problem size.
- Coord<3> const bounds = cutlass::make_Coord(0, n, m);
-
- // The functor.
- Functor functor(params.functor);
+ template <bool kSourceRequired>
+ CUTLASS_DEVICE void epilogue_with_or_without_beta(Accumulators& accumulators,
+ Coord<3> const& block,
+ int batch_id) {
// The C fragment.
typename GlobalLoadIteratorC::Fragment fragment_c;
// The transformed C fragment.
typename GlobalTransformerC::OutputFragment transformed_c;
-
CUTLASS_PRAGMA_UNROLL
for (int h = 0; h < Iterations::kH; ++h) {
// Compute pointer and predicate offsets for C and D global iterators.
@@ -136,6 +118,7 @@ struct GemmEpilogue {
Iterations::kW +
params.stride_h) *
h;
+
int const predicate_offset =
((params.iterator_d.predicate_inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) +
params.iterator_d.predicate_inc_advance) *
@@ -145,32 +128,40 @@ struct GemmEpilogue {
// The iterator to load the elements of the C matrix.
GlobalLoadIteratorC global_load_iterator(
- params.iterator_c, bounds, block, pointer_offset, predicate_offset);
+ params.iterator_c, problem_size, block, pointer_offset, predicate_offset);
+
+ // update C pointer offset based on batch_id and batch_stride_offset
+ //global_load_iterator.add_pointer_offset(batch_id * params.batch_stride_offset_c);
+ global_load_iterator += make_Coord(batch_id, 0, 0);
+
// The transformer for C.
GlobalTransformerC transformer_c;
// The transformer for D.
GlobalTransformerD transformer_d;
// The iterator to store into the D matrix.
GlobalStoreIteratorD global_store_iterator(
- params.iterator_d, bounds, block, pointer_offset, predicate_offset);
+ params.iterator_d, problem_size, block, pointer_offset, predicate_offset);
+
+ // update D pointer offset based on batch_id and batch_stride_offset
+ //global_store_iterator.add_pointer_offset(batch_id * params.batch_stride_offset_d);
+ global_store_iterator += make_Coord(batch_id, 0, 0);
- // The transformer to transform before storing to shared memory.
SharedStoreTransformerD shared_store_transformer;
typename SharedStoreTransformerD::OutputFragment shared_store_transformed_d;
- // The iterator to store to shared memory.
- SharedStoreIteratorD shared_store_iterator(params.shared_store_iterator_d,
- shared_storage.shared_stream.store);
+ SharedStoreIteratorD shared_store_iterator(
+ params.shared_store_iterator_d,
+ reinterpret_cast(shared_storage.data()));
- // The iterator to load from shared memory. TODO: Use a stream.
- SharedLoadIteratorD shared_load_iterator(params.shared_load_iterator_d,
- shared_storage.shared_stream.load);
+ SharedLoadStreamD shared_load_stream(
+ params.shared_load_stream_d,
+ reinterpret_cast(shared_storage.data()));
CUTLASS_PRAGMA_UNROLL
for (int w = 0; w < Iterations::kW; ++w) {
// Load the C matrix into fragment.
- if (!kBetaIsZero_) {
- iterator_load(global_load_iterator, fragment_c);
+ if (kSourceRequired) {
+ global_load_iterator.load_post_increment(fragment_c);
}
// Make sure we can write to shared memory.
@@ -180,33 +171,33 @@ struct GemmEpilogue {
int const offset = (h * Iterations::kW + w) * SharedStoreIteratorD::Fragment::kElements;
shared_store_transformer.transform(accumulators, offset, shared_store_transformed_d);
- shared_iterator_store(shared_store_iterator, shared_store_transformed_d);
+ shared_store_iterator.store_post_increment(shared_store_transformed_d);
// Make sure the data is in shared memory.
shared_store_fence();
// Copy the accumulators back to registers from shared memory.
- typename SharedLoadIteratorD::Fragment fetched_d;
- shared_iterator_load(shared_load_iterator, fetched_d);
+ shared_load_stream.copy();
+ shared_load_stream.commit();
// Do the math.
typename GlobalTransformerD::InputFragment fragment_d;
- if (kBetaIsZero_) {
- functor.evaluate(fetched_d, fragment_d);
- } else {
+ if (kSourceRequired) {
// Transform C fragment.
transformer_c.transform(fragment_c, transformed_c);
// Do the math.
- functor.evaluate(fetched_d, transformed_c, fragment_d);
+ functor.evaluate(shared_load_stream.fragment(), transformed_c, fragment_d);
+ } else {
+ functor.evaluate(shared_load_stream.fragment(), fragment_d);
}
// Transform D fragment.
- typename GlobalTransformerD::OutputFragment transformed_d;
- transformer_d.transform(fragment_d, transformed_d);
+ typename GlobalTransformerD::OutputFragment global_transformed_d;
+ transformer_d.transform(fragment_d, global_transformed_d);
// Copy the results to global memory.
- iterator_store(global_store_iterator, transformed_d);
+ global_store_iterator.store_post_increment(global_transformed_d);
}
}
}
@@ -222,7 +213,9 @@ struct GemmEpilogue {
/// The shared storage.
SharedStorage& shared_storage;
/// The dimensions of the GEMM.
- Index m, n;
+ Coord<3> problem_size;
+ // The functor.
+ Functor functor;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
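
The epilogue above replaces the old is_zero(beta) test with functor.source_required() and keeps the two code paths as separate template instantiations. A standalone sketch of that dispatch shape (scalar loop, illustrative names, not the CUTLASS iterators):

```cpp
#include <cstdio>

// The bool template parameter is resolved at compile time, so the per-element
// loop contains no beta/source branch.
template <bool kSourceRequired>
void epilogue_impl(float alpha, float beta, float const* accum, float const* c, float* d, int n) {
  for (int i = 0; i < n; ++i) {
    if (kSourceRequired) {
      d[i] = alpha * accum[i] + beta * c[i];
    } else {
      d[i] = alpha * accum[i];  // C is never read when the source is not required
    }
  }
}

// One runtime check selects which instantiation runs.
void epilogue(float alpha, float beta, float const* accum, float const* c, float* d, int n) {
  if (beta != 0.0f) {
    epilogue_impl<true>(alpha, beta, accum, c, d, n);
  } else {
    epilogue_impl<false>(alpha, beta, accum, c, d, n);
  }
}

int main() {
  float accum[2] = {1.0f, 2.0f}, c[2] = {10.0f, 20.0f}, d[2];
  epilogue(2.0f, 0.5f, accum, c, d, 2);
  std::printf("%f %f\n", d[0], d[1]);  // 7.0 14.0
  return 0;
}
```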
diff --git a/cutlass/gemm/gemm_epilogue_traits.h b/cutlass/gemm/gemm_epilogue_traits.h
index c06fc25026..c6aff71e14 100644
--- a/cutlass/gemm/gemm_epilogue_traits.h
+++ b/cutlass/gemm/gemm_epilogue_traits.h
@@ -27,13 +27,13 @@
*/
#pragma once
-#include <cutlass/convert.h>
-#include <cutlass/coord.h>
-#include <cutlass/gemm/gemm_global_stream.h>
-#include <cutlass/gemm/gemm_shared_stream.h>
-#include <cutlass/gemm/linear_scaling.h>
-#include <cutlass/reshape_tile.h>
-#include <cutlass/tile_iterator.h>
+#include "cutlass/convert.h"
+#include "cutlass/coord.h"
+#include "cutlass/gemm/gemm_global_stream.h"
+#include "cutlass/gemm/gemm_shared_stream.h"
+#include "cutlass/gemm/linear_scaling.h"
+#include "cutlass/reshape_tile.h"
+#include "cutlass/tile_iterator.h"
namespace cutlass {
namespace gemm {
@@ -57,8 +57,8 @@ template <
typename SharedStoreIteratorD_,
/// The shared store transformer for D.
typename SharedStoreTransformerD_,
- /// The iterator to load D from shared memory.
- typename SharedLoadIteratorD_,
+ /// The stream to load D from shared memory.
+ typename SharedLoadStreamD_,
/// The number of iterations in the epilogue.
typename Iterations_,
/// The iterations strides.
@@ -86,8 +86,8 @@ struct GemmEpilogueTraits {
typedef SharedStoreIteratorD_ SharedStoreIteratorD;
/// The shared store transformer for D.
typedef SharedStoreTransformerD_ SharedStoreTransformerD;
- /// The iterator to store D in shared memory.
- typedef SharedLoadIteratorD_ SharedLoadIteratorD;
+ /// The stream to store D in shared memory.
+ typedef SharedLoadStreamD_ SharedLoadStreamD;
/// typedef typename GemmConfig::EpilogueIterations Iterations;
typedef Iterations_ Iterations;
/// The iterations strides.
@@ -118,14 +118,15 @@ struct GemmEpilogueTraits {
typename GlobalStoreIteratorD::Params iterator_d;
/// The params for the D shared store iterator.
typename SharedStoreIteratorD::Params shared_store_iterator_d;
- /// The params for the D shared load iterator.
- typename SharedLoadIteratorD::Params shared_load_iterator_d;
+ /// The params for the D shared load stream.
+ typename SharedLoadStreamD::Params shared_load_stream_d;
/// The functor params.
typename Functor::Params functor;
/// Setup the params.
template <typename GemmDesc_>
CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
+
// The parameters for the functor.
int error_code = functor.initialize(desc);
if (error_code) {
@@ -133,20 +134,27 @@ struct GemmEpilogueTraits {
}
// At the end of the H iteration, we jump over a number of columns.
- this->stride_h = desc.ldd * Delta::kH;
+ this->stride_h = desc.D.leading_dim() * Delta::kH;
// Nothing to do here.
this->stride_w = 0;
-
// Setup the params for the global memory iterator for C.
- error_code = iterator_c.initialize(
- reinterpret_cast(desc.d_c), desc.ldc, desc.n, stride_w, Delta::kW);
+ error_code = iterator_c.initialize(desc.C.data(),
+ desc.batch_stride_C,
+ desc.C.leading_dim(),
+ desc.problem_size[1],
+ stride_w,
+ Delta::kW);
if (error_code) {
return error_code;
}
// Setup the params for the global memory iterator for D.
- return iterator_d.initialize(
- reinterpret_cast(desc.d_d), desc.ldd, desc.n, stride_w, Delta::kW);
+ return iterator_d.initialize(desc.D.data(),
+ desc.batch_stride_D,
+ desc.D.leading_dim(),
+ desc.problem_size[1],
+ stride_w,
+ Delta::kW);
}
};
@@ -155,13 +163,20 @@ struct GemmEpilogueTraits {
// The storage for the store iterator.
typename SharedStoreIteratorD::SharedStorage store;
// The storage for the load stream.
- typename SharedLoadIteratorD::SharedStorage load;
+ typename SharedLoadStreamD::SharedStorage load;
};
/// The shared memory to swizzle the data in the epilogue.
struct SharedStorage {
// The storage for the shared stream D.
StreamSharedStorage shared_stream;
+
+ //
+ //
+ //
+
+ CUTLASS_DEVICE
+    ScalarD* data() { return reinterpret_cast<ScalarD*>(&shared_stream.load); }
};
};
@@ -192,7 +207,10 @@ struct GemmEpilogueTraitsHelper {
/// The traits class to build the iterator to store to shared memory for D.
typedef GemmSharedStoreTileDTraits<
// The pointer is float.
- typename Functor::Scalar,
+      // typename Functor::Scalar,
+      // Functor::Scalar is the type of alpha and beta. In mixed-precision GEMMs it may differ
+      // from the accumulator type, so Functor::ScalarAccum is used instead.
+ typename Functor::ScalarAccum,
// The output tile size.
typename GemmConfig_::OutputTile,
// The number of warps.
@@ -221,7 +239,10 @@ struct GemmEpilogueTraitsHelper {
/// The traits class to build the iterator to load from shared memory for D.
typedef GemmSharedLoadTileDTraits<
// The pointer is float.
- typename Functor::Scalar,
+      // typename Functor::Scalar,
+      // Functor::Scalar is the type of alpha and beta. In mixed-precision GEMMs it may differ
+      // from the accumulator type, so Functor::ScalarAccum is used instead.
+ typename Functor::ScalarAccum,
// The output tile size.
typename GemmConfig_::OutputTile,
// The number of warps.
@@ -242,6 +263,8 @@ struct GemmEpilogueTraitsHelper {
IteratorAdvance::kH,
MemorySpace::kShared>
SharedLoadIteratorD;
+ /// The stream to load D.
+ typedef SharedLoadStream SharedLoadStreamD;
/// The traits class to build the iterator to load data from global memory for C^N.
typedef GemmGlobalTileCdTraits<
@@ -314,8 +337,8 @@ struct SimplifiedGemmEpilogueTraits : public GemmEpilogueTraits<
typename Helper_::SharedStoreIteratorD,
// The shared store transformer for D.
typename Helper_::SharedStoreTransformerD,
- // The iterator to load D from shared memory.
- typename Helper_::SharedLoadIteratorD,
+ // The stream to load D from shared memory.
+ typename Helper_::SharedLoadStreamD,
// The number of iterations.
typename Helper_::Iterations,
// The strides between iterations.
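The two Functor::ScalarAccum substitutions in this file are the mixed-precision fix called out in the comments above: the shared-memory tile that stages accumulators must be sized by the accumulator type, not by the type of alpha and beta. A small plain-C++ sketch of the distinction; MixedPrecisionFunctorSketch and its members are assumptions for illustration, not CUTLASS types.

```cpp
#include <cstdint>

// Schematic stand-in for the epilogue functor's type members: alpha/beta may be
// stored in a narrower type than the one the MMA instructions accumulate in.
struct MixedPrecisionFunctorSketch {
  typedef std::uint16_t Scalar;  // e.g. a half-precision alpha/beta (storage stand-in)
  typedef float ScalarAccum;     // the accumulator type
};

// The shared-store/shared-load tile traits must be sized by ScalarAccum; sizing them
// by Scalar would truncate accumulators before the scaling is applied.
static_assert(sizeof(MixedPrecisionFunctorSketch::ScalarAccum) >
                  sizeof(MixedPrecisionFunctorSketch::Scalar),
              "illustration only: the two types need not match");
```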
diff --git a/cutlass/gemm/gemm_global_stream.h b/cutlass/gemm/gemm_global_stream.h
index ec675a38fe..6ea72cf30c 100644
--- a/cutlass/gemm/gemm_global_stream.h
+++ b/cutlass/gemm/gemm_global_stream.h
@@ -29,9 +29,10 @@
*/
#pragma once
-#include
-#include
-#include
+#include "cutlass/coord.h"
+#include "cutlass/convert.h"
+#include "cutlass/gemm/gemm_global_tile.h"
+#include "cutlass/tile_allocation.h"
namespace cutlass {
namespace gemm {
@@ -39,6 +40,8 @@ namespace gemm {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <
+ /// Identifies multiplicand
+ GemmOperand::Kind Operand,
/// The load iterator.
typename LoadIterator_,
/// The store iterator to copy to shared memory.
@@ -46,7 +49,9 @@ template <
/// The transformer to be applied after the data has been copied from global memory.
typename Transformer_>
-struct GlobalLoadStreamBase {
+struct GlobalLoadStream {
+ /// Indicates the type of GEMM operand
+ static GemmOperand::Kind const kOperand = Operand;
/// The load iterator.
typedef LoadIterator_ LoadIterator;
/// The transformer.
@@ -75,6 +80,15 @@ struct GlobalLoadStreamBase {
typedef typename LoadIterator::Pointer Pointer;
/// The index.
typedef typename LoadIterator::Index Index;
+ /// The tile
+ typedef typename LoadIterator::Tile Tile;
+
+ /// Shared memory allocation for the tile
+ typedef TileAllocation
+ ThreadblockTileStorage;
+
+ /// Tensor reference to threadblock tile
+ typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef;
/// The params.
struct Params {
@@ -82,56 +96,73 @@ struct GlobalLoadStreamBase {
typename LoadIterator::Params load_iterator;
// The store iterator.
typename StoreIterator::Params store_iterator;
+ // Offset to residue.
+ Index offset_to_residue;
/// Setup the params.
-    template <typename GemmDesc_>
- CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc, Pointer pointer, Index ld) {
- int error_code = load_iterator.initialize(desc, pointer, ld);
+ CUTLASS_HOST_DEVICE int initialize(Pointer pointer,
+ long long batch_stride,
+ Index ldm,
+ Index _offset_to_residue) {
+
+ offset_to_residue = _offset_to_residue;
+ int error_code = load_iterator.initialize(pointer, batch_stride, ldm);
if (error_code) {
return error_code;
}
-
return store_iterator.initialize();
}
};
- /// The amount of storage in shared memory needed to store the tile.
- typedef typename StoreIterator::SharedStorage SharedStoreStorage;
-
- /// The storage in shared memory needed by that stream.
- union SharedStorage {
- // The load iterator.
- typename LoadIterator::SharedStorage load_iterator;
- // The store iterator.
- SharedStoreStorage store_iterator;
- };
+ /// Contains private storage in shared memory needed by the objects within this class. Note,
+ /// this is *NOT* the shared memory allocation for the GEMM threadblock tile. That necessarily
+ /// exists outside this class, as it is also needed by the warp-level shared=>RF stream.
+ struct SharedStorage {};
+
+ //
+ // Static member functions
+ //
+
+ /// Maps a coordinate in the GEMM's (K, N, M) coordinate system to global memory
+ CUTLASS_DEVICE static Coord<3> project_coordinate(Coord<3> const& coord, Index d_offset = 0) {
+ bool const kKstrided =
+ GemmMultiplicandTraits::kKstrided;
+ Coord<3> tile_coord = ProjectOperand::project(coord);
+ return make_Coord(
+ tile_coord[0] + d_offset, tile_coord[1], tile_coord[2] / LoadIterator::Tile::kC);
+ }
/// Ctor.
- CUTLASS_DEVICE GlobalLoadStreamBase(Params const& params,
- SharedStorage& shared_storage,
- Coord<3> const bounds,
- Coord<3> const& block)
- : load_iterator(params.load_iterator, bounds, block),
+ CUTLASS_DEVICE GlobalLoadStream(
+ Params const& _params,
+ SharedStorage& shared_storage,
+ ThreadblockTileRef const& threadblock_tile_ref,
+ Coord<3> const bounds,
+ Coord<3> const& _threadblock_offset)
+ : params(_params),
+ multiplicand_bounds(project_coordinate(bounds, 1)),
+ threadblock_offset(project_coordinate(_threadblock_offset)),
+ load_iterator(params.load_iterator,
+                      project_coordinate(bounds, 1), /*multiplicand_bounds*/
+                      project_coordinate(_threadblock_offset) /*threadblock_offset*/),
transformer(),
- store_iterator(params.store_iterator, shared_storage.store_iterator)
-
+ store_iterator(params.store_iterator, threadblock_tile_ref.data())
{
+ load_iterator.initialize_predicates(multiplicand_bounds, threadblock_offset);
fetched_fragment.clear();
}
+
/// Load the data from shared memory to the fetch fragment.
- CUTLASS_DEVICE void copy() { iterator_load(load_iterator, fetched_fragment); }
+ CUTLASS_DEVICE void copy() { load_iterator.load_post_increment(fetched_fragment); }
/// Commit the data.
CUTLASS_DEVICE void commit() {
transformer.transform(fetched_fragment, transformed_fragment);
- iterator_store(store_iterator, transformed_fragment);
+ store_iterator.store_post_increment(transformed_fragment);
store_iterator.inc_stage();
}
- /// Move to the beginning of the residue code. That's a new code path in CUTLASS 1.0.1.
- CUTLASS_DEVICE void move_to_residue(Index k) { load_iterator.move_to_residue(k); }
-
/// Execute the residue code.
CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
load_iterator.residue(k);
@@ -140,9 +171,43 @@ struct GlobalLoadStreamBase {
}
}
- /// Rollback to the beginning of the GEMM-k dimension.
- CUTLASS_DEVICE void rollback() { load_iterator.rollback(); }
+ /// Move to the residue portion.
+ CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) {
+ Index kResidue = k % kTileK;
+ if (kResidue) {
+ residue(kResidue);
+ }
+ load_iterator.add_pointer_offset(params.offset_to_residue * load_iterator.stride_advance());
+ }
+
+ /// Rollback to the beginning of the first tile
+ CUTLASS_DEVICE void rollback(void) {
+ load_iterator.initialize_predicates(multiplicand_bounds, threadblock_offset);
+
+ int const kBlock = kOperand == GemmOperand::kA
+ ? (kLayout == MatrixLayout::kColumnMajor ? Tile::kH : Tile::kW)
+ : (kLayout == MatrixLayout::kRowMajor ? Tile::kH : Tile::kW);
+
+ load_iterator.add_pointer_offset(-(params.offset_to_residue + kBlock) *
+ load_iterator.stride_advance());
+ }
+
+ /// Adds a Coord<3> to the underlying global load iterator
+ CUTLASS_DEVICE GlobalLoadStream &operator+=(Coord<3> const &offset) {
+ load_iterator += offset;
+ return *this;
+ }
+ //
+ // Data members
+ //
+
+ /// Parameters
+ Params params;
+ /// Multiplicand bounds
+ Coord<3> multiplicand_bounds;
+ /// Threadblock offset
+ Coord<3> threadblock_offset;
/// The iterator.
LoadIterator load_iterator;
/// The fragment to fetch from shared memory.
@@ -155,28 +220,6 @@ struct GlobalLoadStreamBase {
StoreIterator store_iterator;
};
-////////////////////////////////////////////////////////////////////////////////////////////////////
-
-template <
- /// The load iterator.
- typename LoadIterator_,
- /// The store iterator to copy to shared memory.
- typename StoreIterator_,
- /// The transformer to be applied after the data has been copied from global memory.
-    typename Transformer_ = Copy<typename LoadIterator_::Fragment> >
-
-struct GlobalLoadStream : public GlobalLoadStreamBase<LoadIterator_, StoreIterator_, Transformer_> {
-  /// The base class.
-  typedef GlobalLoadStreamBase<LoadIterator_, StoreIterator_, Transformer_> Base;
-
- /// Ctor.
- CUTLASS_DEVICE GlobalLoadStream(typename Base::Params const& params,
- typename Base::SharedStorage& shared_storage,
- Coord<3> const& bounds,
- Coord<3> const& block)
- : Base(params, shared_storage, bounds, block) {}
-};
-
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace gemm
} // namespace cutlass
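GlobalLoadStream, as restructured above, is driven by the mainloop as a copy()/commit() pair, with move_to_residue() handling the leftover GEMM-K iterations and rollback() returning to the first tile. The following is a toy, self-contained sketch of that driving pattern; ToyStream, kTileK, and K are assumptions for illustration, not the real mainloop or tile shapes.

```cpp
#include <cstdio>

// Toy model of the copy()/commit() split: copy() stands in for the global->register
// fetch, commit() for the transform + store into the shared-memory tile.
struct ToyStream {
  int fetched = 0;
  int committed = 0;
  void copy(int k) { fetched = k; }
  void commit() { committed = fetched; }
};

int main() {
  int const kTileK = 8;            // assumed threadblock tile extent in K
  int const K = 20;                // GEMM-K extent, deliberately not a multiple of kTileK
  int const residue = K % kTileK;  // the residue tile is handled first, as move_to_residue() arranges

  ToyStream stream;
  for (int k = 0; k < K; k += (k == 0 && residue) ? residue : kTileK) {
    stream.copy(k);    // fetch the next tile
    stream.commit();   // publish it to "shared memory" for the math to consume
  }
  std::printf("last committed tile starts at k = %d\n", stream.committed);
  return 0;
}
```

The real stream additionally applies the Transformer between the fetch and the shared-memory store (see transformer.transform() in commit() above), which the toy omits.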
diff --git a/cutlass/gemm/gemm_global_tile.h b/cutlass/gemm/gemm_global_tile.h
index 1cc3b3377a..a355ebea0e 100644
--- a/cutlass/gemm/gemm_global_tile.h
+++ b/cutlass/gemm/gemm_global_tile.h
@@ -27,14 +27,14 @@
*/
#pragma once
-#include <cutlass/coord.h>
-#include <cutlass/util/platform.h>
+#include "cutlass/coord.h"
+#include "cutlass/util/platform.h"
-#include <cutlass/gemm/gemm_operand.h>
-#include <cutlass/matrix_traits.h>
-#include <cutlass/predicate_vector.h>
-#include <cutlass/reshape_tile.h>
-#include <cutlass/tile_iterator.h>
+#include "cutlass/gemm/gemm_operand.h"
+#include "cutlass/matrix_traits.h"
+#include "cutlass/predicate_vector.h"
+#include "cutlass/reshape_tile.h"
+#include "cutlass/tile_iterator.h"
namespace cutlass {
namespace gemm {
@@ -80,20 +80,24 @@ struct GemmGlobalTileTraits {
static int const kAccessSize = kAccessSize_;
/// The memory space.
static MemorySpace::Kind const kMemorySpace = MemorySpace::kGlobal;
-
/// The tile shape
-  typedef typename ReshapeTile<Tile_, kAccessSize_>::Tile Tile;
+ typedef Tile_ Tile;
+ /// The vectorized tile shape
+  typedef typename ReshapeTile<Tile_, kAccessSize_>::Tile VectorizedTile;
/// The threads shape
-  typedef typename ReshapeThreads<Tile, Threads_>::Threads Threads;
+  typedef typename ReshapeThreads<VectorizedTile, Threads_>::Threads Threads;
/// The relative offset between two elements in the H/W dimension in adjacent threads.
- typedef Shape<1, 1, Tile::kC> ThreadsDelta;
-
+ typedef Shape<1, 1, VectorizedTile::kC> ThreadsDelta;
/// The strides in each dimension between different loads/stores.
typedef Shape<0, Threads::kH, Threads::kW * kAccessSize> Delta;
+
/// Strides for immediate offset computation
typedef Shape<0, 0, Threads::kW * ThreadsDelta::kW, kAccessSize> ImmediateOffsetStrides;
/// The number of iterations needed to load/store the tile.
- typedef Shape<1, Tile::kH / Threads::kH, Tile::kW / Threads::kW, Tile::kC / kAccessSize>
+ typedef Shape<1,
+ VectorizedTile::kH / Threads::kH,
+ VectorizedTile::kW / Threads::kW,
+ VectorizedTile::kC / kAccessSize>
Iterations;
typedef GemmMultiplicandTraits MultiplicandTraits;
@@ -165,7 +169,6 @@ struct GemmGlobalIteratorAb
Index_> {
/// This class.
typedef GemmGlobalIteratorAb<TileTraits_, Index_> This_;
/// The base class.
-
typedef TileLoadIterator
- CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc, Scalar const* ptr, Index stride_h) {
+ CUTLASS_HOST_DEVICE int initialize(Scalar const* ptr,
+ long long stride_d,
+ Index stride_h) {
Index inc_d = 0;
Index inc_advance = 0;
// Move by some columns for each iteration in the H dimension.
@@ -221,99 +227,36 @@ struct GemmGlobalIteratorAb
(Base::Iterations::kH - 1) * inc_h;
}
- // The dimensions of the tile.
- int const kH = TileTraits_::Tile::kH;
- int const kW = TileTraits_::Tile::kW * TileTraits_::kAccessSize;
-
- // Move to the residue.
- Index const kBlock = kAdvance == IteratorAdvance::kH ? kH : kW;
- // The jump in the gemm-k dimension.
- Index const stride = kAdvance == IteratorAdvance::kH ? stride_h : 1;
-
- // Compute the offset to the residue and how to "come" back.
- Index const kResidue = desc.k % kBlock;
- if (kResidue > 0) {
- move_to_residue_offset = (desc.k - kResidue) * stride;
- } else {
- move_to_residue_offset = (desc.k - kBlock) * stride;
- }
-
- Base::Params::initialize(ptr, 0, stride_h, 1, inc_d, inc_h, 0, inc_advance);
+ Base::Params::initialize(
+ ptr, stride_d, stride_h, 1, inc_d, inc_h, 0, inc_advance);
return 0;
}
-
- // The extra offset to control moving to the residue.
- Index move_to_residue_offset;
};
- /// Ctor.
- CUTLASS_DEVICE GemmGlobalIteratorAb(Params const& _params,
- const Coord<3>& bounds,
- const Coord<3>& block,
- ThreadOffset thread_offset_func = ThreadOffset())
- : params(_params) {
- thread_offset = thread_offset_func();
- // The column.
- Index block_h = thread_offset[1];
- // The contiguous dimension.
- Index block_w = thread_offset[2];
-
- // Add the blocks indices.
- if (kAdvance == IteratorAdvance::kH) {
- block_h += block[1];
- block_w += block[2];
-
- } else {
- block_h += block[2];
- block_w += block[1];
- }
-
- // Setup the pointer.
- params.pointer += (block_h * params.stride_h + block_w);
-
- // Initialize predicates
- initialize_predicates(bounds, make_Coord(0, block_h, block_w));
- }
-
- /// The accessor.
- CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
- int const imm =
- ComputeOffsetFromStrides::get(0, 0, w, c);
- Load::load(value, params.pointer, imm);
- }
-
- /// Increment the pointer in the H dimension.
- CUTLASS_DEVICE void inc_h() { params.pointer += params.inc_h; }
- /// Increment the pointer in the D dimension.
- CUTLASS_DEVICE void inc_d() { params.pointer += params.inc_d; }
- /// Increment the pointer to move to the next iteration.
- CUTLASS_DEVICE void inc_advance() { params.pointer += params.inc_advance; }
+ /// Offset of an individual lane from the start of the tile
+ Coord<4> thread_offset;
+ /// The parameters
+ Params params;
+ /// The predicates.
+ PredicateVector predicates;
- /// Initialize the predicates.
- CUTLASS_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& block) {
+ CUTLASS_HOST_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& block_offset) {
// Setup the masks to control loads.
predicates.fill(0);
- int bounds_h, bounds_w;
- if (kAdvance == IteratorAdvance::kH) {
- bounds_w = bounds[2] - block[2];
- bounds_h = bounds[1];
-
- } else {
- bounds_w = bounds[1];
- bounds_h = bounds[2] - block[1];
- }
-
// Fill in the bits of the predicate vector.
for (int d = 0; d < Base::Iterations::kD; ++d) {
for (int h = 0; h < Base::Iterations::kH; ++h) {
for (int w = 0; w < Base::Iterations::kW; ++w) {
for (int c = 0; c < Base::Iterations::kC; ++c) {
- bool flag = w * Base::Delta::kW < bounds_w;
+ bool flag = w * Base::Delta::kW + thread_offset[2] + block_offset[2] < bounds[2];
if (kAdvance == IteratorAdvance::kH) {
- flag = flag && (h * Base::Delta::kH + d * Base::Delta::kD) < bounds_h;
+ flag =
+ flag &&
+ (h * Base::Delta::kH + d * Base::Delta::kD) + thread_offset[1] + block_offset[1] <
+ bounds[1];
} else {
- flag = flag && (h * Base::Delta::kH) < bounds_h;
+ flag = flag && (h * Base::Delta::kH) + thread_offset[1] + block_offset[1] < bounds[1];
}
int const bit = ComputeOffsetFromShape::get(d, h, w, c);
predicates.set(bit, flag);
@@ -323,31 +266,44 @@ struct GemmGlobalIteratorAb
}
}
- /// Move to residue portion.
- CUTLASS_DEVICE void move_to_residue(Index k) {
- // Store the pointer and the predicates.
- stored_pointer = params.pointer;
- stored_predicates = predicates;
-
- // Move the pointer to the residue.
- params.pointer += params.move_to_residue_offset;
+ /// Ctor.
+ CUTLASS_HOST_DEVICE GemmGlobalIteratorAb(Params const& _params,
+ const Coord<3>& bounds,
+ const Coord<3>& threadblock_offset,
+ ThreadOffset thread_offset_func = ThreadOffset())
+ : params(_params) {
+ thread_offset = thread_offset_func();
+ // Setup the pointer.
+ params.pointer += ((threadblock_offset[1] + thread_offset[1]) * params.stride_h +
+ (threadblock_offset[2] + thread_offset[2]));
- // The dimensions of the tile.
- int const kH = TileTraits_::Tile::kH;
- int const kW = TileTraits_::Tile::kW * TileTraits_::kAccessSize;
+ }
- // The unrolling factor.
- int const kUnroll = kAdvance == IteratorAdvance::kH ? kH : kW;
+ /// Increment the pointer in the W dimension.
+ CUTLASS_HOST_DEVICE void inc_w() { Base::inc_w(); }
+ /// Increment the pointer in the H dimension.
+ CUTLASS_HOST_DEVICE void inc_h() { params.pointer += params.inc_h; }
+ /// Increment the pointer in the D dimension.
+ CUTLASS_HOST_DEVICE void inc_d() { params.pointer += params.inc_d; }
+ /// Increment the pointer to move to the next iteration.
+ CUTLASS_HOST_DEVICE void inc_advance() { params.pointer += params.inc_advance; }
- // Clear the predicates for the residue. TODO: We can do something smarter.
- int const kResidue = (int)(k % (Index)kUnroll);
- if (kResidue > 0) {
- residue(kResidue);
- }
+ /// Loads a single fragment element from memory
+ CUTLASS_HOST_DEVICE void load_element(
+ typename Base::AccessType& value, int d, int h, int w, int c) const {
+ int const offset =
+ ComputeOffsetFromStrides::get(0, 0, w, c);
+ Load::load(value, params.pointer, offset);
}
/// That's the residue! Update the predicates.
- CUTLASS_DEVICE void residue(Index k) {
+ CUTLASS_HOST_DEVICE void residue(Index k) {
// The coordinates of the thread.
Index block_h = thread_offset[1];
// The contiguous dimension.
@@ -375,26 +331,63 @@ struct GemmGlobalIteratorAb
}
}
- /// Rollback to beginning of first tile and initialize predicates.
- CUTLASS_DEVICE void rollback() {
- params.pointer = stored_pointer;
- predicates = stored_predicates;
- }
-
- /// Is the iterator valid?
- CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
+  /// Is the iterator valid?
+ CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const {
int const bit = ComputeOffsetFromShape::get(d, h, w, c);
return predicates[bit];
}
- /// Offset of an individual lane from the start of the tile
- Coord<4> thread_offset;
- /// The parameters
- Params params;
- /// The pointer.
- typename Base::Scalar const* stored_pointer;
- /// The predicates.
- PredicateVector predicates, stored_predicates;
+ /// Adds a vector offset to the iterator
+ CUTLASS_HOST_DEVICE GemmGlobalIteratorAb & operator+=(Coord<3> const &offset) {
+
+    long long _offset = offset.template dot<long long>(
+ make_Coord(params.stride_d, params.stride_h, params.stride_w)
+ );
+
+ params.pointer += _offset;
+ return *this;
+ }
+
+ CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset) { params.pointer += offset; }
+
+ CUTLASS_HOST_DEVICE Index stride_advance(void) {
+ Index stride = params.stride_h;
+ if (kAdvance == IteratorAdvance::kW) {
+ stride = params.stride_w;
+ }
+ return stride;
+ }
+
+  template <typename Fragment>
+ CUTLASS_HOST_DEVICE void load_post_increment(Fragment& fragment) {
+ typename Base::FragmentIterator frag_iterator(fragment);
+ for (int d = 0; d < Base::Iterations::kD; ++d) {
+ for (int h = 0; h < Base::Iterations::kH; ++h) {
+ for (int w = 0; w < Base::Iterations::kW; ++w) {
+ for (int c = 0; c < Base::Iterations::kC; ++c) {
+ if (valid(d, h, w, c)) {
+ load_element(
+                reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
+ d,
+ h,
+ w,
+ c);
+ }
+ }
+ if (w < Base::Iterations::kW - 1) {
+ inc_w();
+ }
+ }
+ if (h < Base::Iterations::kH - 1) {
+ inc_h();
+ }
+ }
+ if (d < Base::Iterations::kD - 1) {
+ inc_d();
+ }
+ }
+ inc_advance();
+ }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
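load_post_increment() above replaces the old free-function iterator_load(): it walks the Iterations shape, loads only elements whose predicate bit is set, and advances the pointer with inc_w()/inc_h()/inc_d() plus a final inc_advance(). A self-contained sketch of that predicate-guarded traversal in two dimensions follows; ToyTileIterator and its members are assumptions for illustration, not the real iterator or tile traits.

```cpp
#include <bitset>

// Toy 2-D traversal standing in for the Iterations walk in load_post_increment().
template <int kH, int kW>
struct ToyTileIterator {
  float const* pointer;             // current load address (params.pointer)
  int stride_h;                     // elements to advance per H step (params.inc_h)
  std::bitset<kH * kW> predicates;  // one guard bit per element, as set up by initialize_predicates()

  template <typename Fragment>
  void load_post_increment(Fragment& frag) {
    float const* p = pointer;
    int idx = 0;
    for (int h = 0; h < kH; ++h) {
      for (int w = 0; w < kW; ++w, ++idx) {
        if (predicates[idx]) {      // valid(d, h, w, c)
          frag[idx] = p[w];         // load_element(...)
        }
      }
      p += stride_h;                // inc_h()
    }
    pointer = p;                    // inc_advance(): ready for the next tile
  }
};
```

A Fragment here can be as simple as std::array<float, kH * kW>; the real fragments carry vectorized AccessType elements instead.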
@@ -433,6 +426,8 @@ struct GemmGlobalIteratorCd : public TileIteratorBase
this->pointer = pointer;
+ // Stride per batch
+ stride_d = batch_stride;
// Each column of the matrix.
- stride_h = TileTraits_::ThreadsDelta::kH * ld;
+ stride_h = TileTraits_::ThreadsDelta::kH * ldm;
// Each thread output 1 column per iteration. The stride between columns is given by the
// number of scalars that are loaded per LDS for B.
- inc_h = ld * TileTraits_::kStrideH;
+ inc_h = ldm * TileTraits_::kStrideH;
inc_advance =
- (ld - ld * TileTraits_::kStrideH * (Base::Iterations::kH - 1)) + epilogue_stride_w;
+ (ldm - ldm * TileTraits_::kStrideH * (Base::Iterations::kH - 1)) + epilogue_stride_w;
predicate_offset = bound;
predicate_inc_h = TileTraits_::kStrideH;
@@ -464,75 +465,173 @@ struct GemmGlobalIteratorCd : public TileIteratorBase
+  Coord<4> thread_offset;
+ /// The predicates for the row.
+  cutlass::PredicateVector<Base::Iterations::kW> predicates;
/// Ctor.
- CUTLASS_DEVICE GemmGlobalIteratorCd() {}
+ CUTLASS_HOST_DEVICE GemmGlobalIteratorCd(Params const& _params,
+ const Coord<3>& bounds,
+ const Coord<3>& block_offset,
+ ThreadOffset thread_offset_func = ThreadOffset())
+ : params(_params) {
+ thread_offset = thread_offset_func();
+ // Prepare the vector of predicates.
+ for (int i = 0; i < Base::Iterations::kW; ++i) {
+ predicates.set(i, thread_offset[2] + i * Base::Delta::kW < bounds[2]);
+ }
+ }
/// Ctor.
- CUTLASS_DEVICE GemmGlobalIteratorCd(Params const& params,
- const Coord<3>& bounds,
- const Coord<3>& block,
- int offset = 0,
- int pred_offset = 0,
- ThreadOffset thread_offset_func = ThreadOffset())
- : params(params) {
+ CUTLASS_HOST_DEVICE GemmGlobalIteratorCd(Params const& _params,
+ const Coord<3>& bounds,
+ const Coord<3>& block,
+ int offset = 0,
+ int pred_offset = 0,
+ ThreadOffset thread_offset_func = ThreadOffset())
+ : params(_params) {
thread_offset = thread_offset_func();
// Each warp works on a different column of the tile.
int const h = thread_offset[1] + block[1];
// Each lane writes a different element.
int const w = thread_offset[2] + block[2];
// Setup the pointer.
- this->params.pointer += ((h * params.stride_h + w) + offset);
+ params.pointer += ((h * params.stride_h + w) + offset);
// Prepare the vector of predicates.
for (int i = 0; i < Base::Iterations::kW; ++i) {
predicates.set(i, w + i * Base::Delta::kW < bounds[2]);
}
- this->params.predicate_offset -= (h + pred_offset);
- }
-
- /// The accessor.
- CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const {
- int const imm =
- ComputeOffsetFromStrides::get(0, 0, w, c);
- Load::load(value, params.pointer, imm);
+ params.predicate_offset -= (h + pred_offset);
}
/// Increment the pointer in the C dimension.
- CUTLASS_DEVICE void inc_c() {}
+ CUTLASS_HOST_DEVICE void inc_c() {}
/// Increment the pointer in the W dimension.
- CUTLASS_DEVICE void inc_w() {}
+ CUTLASS_HOST_DEVICE void inc_w() {}
/// Increment the pointer in the H dimension.
- CUTLASS_DEVICE void inc_h() {
+ CUTLASS_HOST_DEVICE void inc_h() {
params.pointer += params.inc_h;
params.predicate_offset -= params.predicate_inc_h;
}
/// Increment the pointer in the D dimension.
- CUTLASS_DEVICE void inc_d() {}
+ CUTLASS_HOST_DEVICE void inc_d() {}
/// Increment the pointer to move to the next iteration.
- CUTLASS_DEVICE void inc_advance() {
+ CUTLASS_HOST_DEVICE void inc_advance() {
params.pointer += params.inc_advance;
- this->params.predicate_offset -= params.predicate_inc_advance;
+ params.predicate_offset -= params.predicate_inc_advance;
}
- /// The accessor.
- CUTLASS_DEVICE void set(typename Base::AccessType const& value, int d, int h, int w, int c) {
- int const imm =
- ComputeOffsetFromStrides::get(0, 0, w, c);
- Store::store(
- value, params.pointer, imm);
+ /// Adds a vector offset to the iterator
+ CUTLASS_HOST_DEVICE GemmGlobalIteratorCd & operator+=(Coord<3> const &offset) {
+    long long _offset = offset.template dot<long long>(
+ make_Coord(params.stride_d, params.stride_h, 1)
+ );
+ params.pointer += _offset;
+ return *this;
+ }
+
+ /// Loads a single fragment element from memory.
+ CUTLASS_HOST_DEVICE void load_element(
+ typename Base::AccessType& value, int d, int h, int w, int c) const {
+ int const offset =
+ ComputeOffsetFromStrides::get(d, h, w, c);
+ Load::load(value, params.pointer, offset);
}
- /// Test the validity of the iterator.
- CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
+ /// Stores a single fragment element into memory.
+ CUTLASS_HOST_DEVICE void store_element(
+ typename Base::AccessType const& value, int d, int h, int w, int c) {
+ int const offset =
+ ComputeOffsetFromStrides::get(d, h, w, c);
+ Store::store(value, params.pointer, offset);
+ }
+
+  /// Test the validity of the iterator.
+ CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const {
return predicates.at(w) && params.predicate_offset > 0;
}
- /// The predicates for the row.
-  cutlass::PredicateVector<Base::Iterations::kW> predicates;
+  /// Adds a pointer offset to the iterator.
+ CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset) { params.pointer += offset; }
+
+ /// Loads and increments iterator
+  template <typename Fragment>
+ CUTLASS_HOST_DEVICE void load_post_increment(Fragment& fragment) {
+ typename Base::FragmentIterator frag_iterator(fragment);
+ for (int d = 0; d < Base::Iterations::kD; ++d) {
+ for (int h = 0; h < Base::Iterations::kH; ++h) {
+ for (int w = 0; w < Base::Iterations::kW; ++w) {
+ for (int c = 0; c < Base::Iterations::kC; ++c) {
+ if (valid(d, h, w, c)) {
+ load_element(
+                reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
+ d,
+ h,
+ w,
+ c);
+ }
+ }
+ if (w < Base::Iterations::kW - 1) {
+ inc_w();
+ }
+ }
+ if (h < Base::Iterations::kH - 1) {
+ inc_h();
+ }
+ }
+ if (d < Base::Iterations::kD - 1) {
+ inc_d();
+ }
+ }
+ inc_advance();
+ }
+
+  template <typename Fragment>
+ CUTLASS_HOST_DEVICE void store_post_increment(Fragment& fragment) {
+ typename Base::FragmentIterator frag_iterator(fragment);
+ for (int d = 0; d < Base::Iterations::kD; ++d) {
+ for (int h = 0; h < Base::Iterations::kH; ++h) {
+ for (int w = 0; w < Base::Iterations::kW; ++w) {
+ for (int c = 0; c < Base::Iterations::kC; ++c) {
+ if (valid(d, h, w, c)) {
+ store_element(
+                reinterpret_cast<typename Base::AccessType&>(frag_iterator.at(d, h, w, c)),
+ d,
+ h,
+ w,
+ c);
+ }
+ }
+ if (w < Base::Iterations::kW - 1) {
+ inc_w();
+ }
+ }
+ if (h < Base::Iterations::kH - 1) {
+ inc_h();
+ }
+ }
+ if (d < Base::Iterations::kD - 1) {
+ inc_d();
+ }
+ }
+ inc_advance();
+ }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
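The operator+=(Coord<3>) overloads added to both global iterators reduce a (D, H, W) offset to a single pointer displacement by dotting it with the iterator strides, which is how the batched strided GEMM path steps an iterator to the next batch. A minimal sketch of that reduction; the stride and offset values are assumptions for illustration.

```cpp
#include <cstdio>

// Reduce a 3-component tile offset to a linear pointer offset, as the iterators'
// operator+= does with (stride_d, stride_h, stride_w).
long long linear_offset(const int (&offset)[3], const long long (&strides)[3]) {
  long long result = 0;
  for (int i = 0; i < 3; ++i) {
    result += static_cast<long long>(offset[i]) * strides[i];
  }
  return result;
}

int main() {
  long long const batch_stride = 1024 * 1024;  // assumed per-batch stride (batch_stride_C / batch_stride_D)
  long long const ldm = 1024;                  // assumed leading dimension
  long long const strides[3] = {batch_stride, ldm, 1};
  int const next_batch[3] = {1, 0, 0};         // advance one batch in the D dimension
  std::printf("pointer += %lld\n", linear_offset(next_batch, strides));
  return 0;
}
```

In the real iterators the D stride is the per-batch stride captured by Params::initialize above, so adding make_Coord(1, 0, 0) moves the pointer to the same tile of the next matrix in the batch.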
diff --git a/cutlass/gemm/gemm_operand.h b/cutlass/gemm/gemm_operand.h
index 737f993f01..2b4dcdc916 100644
--- a/cutlass/gemm/gemm_operand.h
+++ b/cutlass/gemm/gemm_operand.h
@@ -28,9 +28,9 @@
*/
#pragma once
-#include
-#include
-#include