Skip to content

Commit

Permalink
[Build] Guard against older CUDA versions when building CUTLASS 3.x k…
Browse files Browse the repository at this point in the history
…ernels (#5168)
  • Loading branch information
tlrmchlsmth authored Jun 1, 2024
1 parent 6575791 commit 1197e02
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
10 changes: 8 additions & 2 deletions csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>

#if defined CUDA_VERSION && CUDA_VERSION >= 12000

#include <torch/extension.h>

#include <ATen/cuda/CUDAContext.h>
Expand All @@ -6,8 +12,6 @@
#include <sstream>
#include <vector>

// clang-format will break include orders
// clang-format off
#include "cutlass/cutlass.h"

#include "cute/tensor.hpp"
Expand Down Expand Up @@ -241,3 +245,5 @@ void cutlass_scaled_mm_dq_sm90(torch::Tensor& out, torch::Tensor const& a,
}
}
}

#endif
11 changes: 10 additions & 1 deletion csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

void cutlass_scaled_mm_dq_sm75(torch::Tensor& c, torch::Tensor const& a,
Expand All @@ -17,10 +18,12 @@ void cutlass_scaled_mm_dq_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);

#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_dq_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
#endif

void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
Expand Down Expand Up @@ -51,7 +54,13 @@ void cutlass_scaled_mm_dq(torch::Tensor& c, torch::Tensor const& a,

if (version_num >= 90) {
// Hopper

// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
cutlass_scaled_mm_dq_sm90(c, a, b, a_scales, b_scales);
#else
cutlass_scaled_mm_dq_sm80(c, a, b, a_scales, b_scales);
#endif
} else if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_dq_sm89(c, a, b, a_scales, b_scales);
Expand Down

0 comments on commit 1197e02

Please sign in to comment.