Skip to content

Commit

Permalink
Make cutlass::gemm::device::GemmArray usable (NVIDIA#295)
Browse files Browse the repository at this point in the history
* Fix the build of cutlass/gemm/device/gemm_array.h and add a demo for GemmArray

* Add a reference to GemmArray to the docs

Co-authored-by: Ivan Komarov <dfyz@yandex-team.ru>
  • Loading branch information
dfyz and Ivan Komarov authored Feb 18, 2022
1 parent 3cfa5db commit e96f005
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 12 deletions.
134 changes: 124 additions & 10 deletions examples/05_batched_gemm/batched_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,16 @@

#include "cutlass/cutlass.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/device/gemm_array.h"
#include "cutlass/gemm/device/gemm_batched.h"

#pragma warning( disable : 4503)

/*
This example demonstrates how to use cutlass to compute a batched strided gemm.
This example demonstrates how to use cutlass to compute a batched gemm in two different ways:
1. By specifying pointers to the first matrices of the batch and the stride between the consecutive
matrices of the batch (this is called a strided batched gemm).
2. By copying pointers to all matrices of the batch to the device memory (this is called an array gemm).
In this example, both A and B matrix are non-transpose and column major matrix
batched_C = batched_A x batched_B
As an example, matrix C can be seen as
Expand Down Expand Up @@ -89,6 +93,45 @@ The stride (batch_stride_C) between the first element of two batches is k
*/

cudaError_t cutlass_array_sgemm(
int m,
int n,
int k,
float alpha,
float const * const *A,
int lda,
float const * const *B,
int ldb,
float * const *C,
int ldc,
float beta,
int batch_count) {

using Gemm = cutlass::gemm::device::GemmArray<
float, cutlass::layout::ColumnMajor,
float, cutlass::layout::ColumnMajor,
float, cutlass::layout::ColumnMajor
>;

Gemm gemm_op;

cutlass::Status status = gemm_op({
{m, n, k},
A, lda,
B, ldb,
C, ldc,
C, ldc,
{alpha, beta},
batch_count
});

if (status != cutlass::Status::kSuccess) {
return cudaErrorUnknown;
}

return cudaSuccess;
}

cudaError_t cutlass_strided_batched_sgemm(
int m,
int n,
Expand Down Expand Up @@ -188,7 +231,10 @@ cudaError_t strided_batched_gemm_nn_reference(
return result;
}

int main() {
cudaError_t run_batched_gemm(bool use_array) {

const char* gemm_desc = use_array ? "array" : "strided batched";
std::cout << "Running " << gemm_desc << " gemm" << std::endl;

// Arbitrary problem size
int const m = 520;
Expand Down Expand Up @@ -293,11 +339,69 @@ int main() {
}

// run cutlass
result = cutlass_strided_batched_sgemm(
m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
beta, batch_count);
if (result != cudaSuccess)
return result;
if (use_array) {
// allocate the host memory for the pointers to the matrices of the batch
std::vector<float*> host_ptr_A(batch_count);
std::vector<float*> host_ptr_B(batch_count);
std::vector<float*> host_ptr_C(batch_count);

// permute the batch elements to emphasize that GemmArray does not depend on matrices being separated by a fixed stride
std::vector<size_t> permutation = {14, 11, 3, 10, 1, 13, 9, 4, 6, 16, 8, 15, 7, 12, 0, 2, 5};
for (size_t b_idx = 0; b_idx < batch_count; b_idx++) {
host_ptr_A[b_idx] = A + permutation[b_idx] * batch_stride_A;
host_ptr_B[b_idx] = B + permutation[b_idx] * batch_stride_B;
host_ptr_C[b_idx] = C + permutation[b_idx] * batch_stride_C;
}

// allocate the corresponding device memory
float const **ptr_A;
float const **ptr_B;
float **ptr_C;

result = cudaMalloc(&ptr_A, batch_count * sizeof(float*));
if (result != cudaSuccess) {
std::cerr << "cudaMalloc result = " << result << std::endl;
return result;
}
result = cudaMalloc(&ptr_B, batch_count * sizeof(float*));
if (result != cudaSuccess) {
std::cerr << "cudaMalloc result = " << result << std::endl;
return result;
}
result = cudaMalloc(&ptr_C, batch_count * sizeof(float*));
if (result != cudaSuccess) {
std::cerr << "cudaMalloc result = " << result << std::endl;
return result;
}

// copy the matrix pointers to the device
result = cudaMemcpy(ptr_A, host_ptr_A.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
if (result != cudaSuccess) {
std::cerr << "cudaMemcpy result = " << result << std::endl;
return result;
}
result = cudaMemcpy(ptr_B, host_ptr_B.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
if (result != cudaSuccess) {
std::cerr << "cudaMemcpy result = " << result << std::endl;
return result;
}
result = cudaMemcpy(ptr_C, host_ptr_C.data(), batch_count * sizeof(float*), cudaMemcpyHostToDevice);
if (result != cudaSuccess) {
std::cerr << "cudaMemcpy result = " << result << std::endl;
return result;
}

result = cutlass_array_sgemm(m, n, k, alpha, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, beta, batch_count);

if (result != cudaSuccess)
return result;
} else {
result = cutlass_strided_batched_sgemm(
m, n, k, alpha, A, lda, batch_stride_A, B, ldb, batch_stride_B, C, ldc, batch_stride_C,
beta, batch_count);
if (result != cudaSuccess)
return result;
}

// copy device memory to host
result = cudaMemcpy(result_C.data(), C, count_C * sizeof(float), cudaMemcpyDeviceToHost);
Expand All @@ -314,7 +418,7 @@ int main() {

// Expect bit-level accuracy for this simple example
if (ref_C != result_C) {
std::cout << "CUTLASS strided batched gemm does not run correctly" << std::endl;
std::cout << "CUTLASS " << gemm_desc << " gemm does not run correctly" << std::endl;
return cudaErrorUnknown;
}

Expand All @@ -335,9 +439,19 @@ int main() {
return result;
}

return result;
}

int main() {

if (result == cudaSuccess) {
std::cout << "Passed." << std::endl;
cudaError_t result = cudaSuccess;
for (bool use_array : {false, true}) {
result = run_batched_gemm(use_array);
if (result == cudaSuccess) {
std::cout << "Passed." << std::endl;
} else {
break;
}
}

// Exit.
Expand Down
4 changes: 2 additions & 2 deletions include/cutlass/gemm/device/gemm_array.h
Original file line number Diff line number Diff line change
Expand Up @@ -376,8 +376,8 @@ class GemmArray {

cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
args.problem_size,
args.batch_count,
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK});
{ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
args.batch_count);

// Initialize the Params structure
params_ = typename GemmKernel::Params{
Expand Down
1 change: 1 addition & 0 deletions media/docs/gemm_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ has semantics similar to cuBLAS.

The device-wide GEMM API is embodied by the following operators:
- [cutlass::gemm::device::Gemm](/include/cutlass/gemm/device/gemm.h) - basic GEMM operation
- [cutlass::gemm::device::GemmArray](/include/cutlass/gemm/device/gemm_array.h) - batched GEMM operation in which input matrices are read from arrays of pointers
- [cutlass::gemm::device::GemmBatched](/include/cutlass/gemm/device/gemm_batched.h) - batched GEMM operation in which input matrices are separated by a constant stride
- [cutlass::gemm::device::GemmSplitKParallel](/include/cutlass/gemm/device/gemm_splitk_parallel.h) - GEMM operation that partitions the GEMM K dimension then launches a separate reduction kernel

Expand Down

0 comments on commit e96f005

Please sign in to comment.