Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Update CUTLASS to 3.5.1 #7085

Merged
merged 4 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
CUDA graph fix -- use torch::empty for workspace
  • Loading branch information
tlrmchlsmth committed Aug 5, 2024
commit e80e8cbd7617ec957d61b3de06b3f4b9692c1514
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ define_gpu_extension_target(
SOURCES ${VLLM_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)

Expand Down
7 changes: 3 additions & 4 deletions csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/util/device_memory.h"

#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
Expand Down Expand Up @@ -301,12 +299,13 @@ inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
// Launch the CUTLASS GEMM kernel.
typename Gemm::Op gemm_op;
size_t workspace_size = gemm_op.get_workspace_size(args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);

auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

CUTLASS_CHECK(gemm_op.can_implement(args));
cutlass::Status status = gemm_op(args, workspace.get(), stream);
cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status);
}

Expand Down
7 changes: 3 additions & 4 deletions csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"

#include "cutlass/util/device_memory.h"

#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
Expand Down Expand Up @@ -274,11 +272,12 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
CUTLASS_CHECK(gemm_op.can_implement(args));

size_t workspace_size = gemm_op.get_workspace_size(args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);

auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status);
}

Expand Down
Loading