
Commit e45129d

charlifugshtras and Gregory Shtrasberg authored
Allocate workspace for hipblaslt fp8 gemm. (vllm-project#78)
* Initializing hipblaslt workspace for fp8 gemms
* Make workspace size configurable
* Assign default value for workspace pointer
* Fix clang-format
* Fix clang-format

Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
1 parent 52df169 commit e45129d
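
In short, the patch replaces the hard-coded workspace constant with a single device allocation whose size is read from the HIPBLASLT_WORKSPACE_SIZE environment variable (interpreted in KiB, defaulting to 32 MiB) and exposed through a new create_workspace() op. A minimal sketch of how this could be driven from Python, assuming `vllm_ops` is the compiled extension module used in fp8_rocm.py (the exact import path below is an assumption):

import os

# Optional override before allocation: the value is read as KiB by
# get_hipblaslt_workspace_size(); 76 * 1024 KiB = 76 MiB, the size the
# hipblaslt author recommends per the comment in gemm_kernel.cu.
os.environ["HIPBLASLT_WORKSPACE_SIZE"] = str(76 * 1024)

# Assumed import path for the compiled ops module; fp8_rocm.py only shows
# it being used under the name `vllm_ops`.
from vllm._C import ops as vllm_ops

# One hipMalloc of the configured size; later fp8_gemm / fp8_gemm_16 calls
# reuse this workspace via gemm.initialize(..., workspace).
vllm_ops.create_workspace()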

File tree

4 files changed: +40 −6 lines

csrc/ops.h
csrc/pybind.cpp
csrc/quantization/fp8/amd/gemm_kernel.cu
vllm/model_executor/layers/quantization/fp8_rocm.py

csrc/ops.h

Lines changed: 2 additions & 0 deletions
@@ -124,6 +124,8 @@ torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b,
 torch::Tensor fp8_gemm_16(torch::Tensor& a, torch::Tensor& b,
                           torch::Tensor& scaleA, torch::Tensor& scaleB,
                           int algo_idx);
+
+void create_workspace();
 #endif
 
 void moe_align_block_size(torch::Tensor topk_ids, int num_experts,

csrc/pybind.cpp

Lines changed: 2 additions & 0 deletions
@@ -72,6 +72,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 #ifdef USE_ROCM
   ops.def("fp8_gemm", &fp8_gemm, "fp8 GEMM with fp8 output");
   ops.def("fp8_gemm_16", &fp8_gemm_16, "fp8 GEMM with fp16 output");
+  ops.def("create_workspace", &create_workspace,
+          "Create workspace for fp8 GEMM");
 #endif
 
   ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,

csrc/quantization/fp8/amd/gemm_kernel.cu

Lines changed: 35 additions & 6 deletions
@@ -9,8 +9,6 @@
 #include <hipblaslt/hipblaslt.h>
 #include <hipblaslt/hipblaslt-ext.hpp>
 
-#define max_workspace_size 2 * 128 * 1024 * 1024
-
 #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) \
   TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
@@ -36,6 +34,37 @@
 }
 #endif
 
+static void* workspace = nullptr;
+static size_t workspace_size;
+
+// Copied from
+// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/cuda/tunable/GemmHipblaslt.h
+static size_t get_hipblaslt_workspace_size() {
+  static const char* env = getenv("HIPBLASLT_WORKSPACE_SIZE");
+  // 256MB is max workspace size allowed for hipblaslt
+  // hipblaslt-bench uses 32MB
+  // recommendation from hipblaslt author was 76MB
+  size_t workspace_size = 32 * 1024;  // going with 32MB
+  if (env) {
+    try {
+      workspace_size = std::stoi(env);
+    } catch (std::invalid_argument const& e) {
+      TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,",
+                 " using default workspace size of ", workspace_size, " KiB.");
+    } catch (std::out_of_range const& e) {
+      TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,",
+                 " using default workspace size of ", workspace_size, " KiB.");
+    }
+  }
+  return workspace_size * 1024;
+}
+
+void create_workspace() {
+  workspace_size = get_hipblaslt_workspace_size();
+  if (workspace_size > 0)
+    CHECK_HIP_ERROR(hipMalloc(&workspace, workspace_size));
+}
+
 torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b,
                        torch::Tensor& scaleA, torch::Tensor& scaleB,
                        torch::Tensor& scaleD, int algo_idx) {
@@ -116,7 +145,7 @@ torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b,
   auto stream = at::cuda::getCurrentCUDAStream();
 
   hipblaslt_ext::GemmPreference gemmPref;
-  gemmPref.setMaxWorkspaceBytes(0);
+  gemmPref.setMaxWorkspaceBytes(workspace_size);
   hipblaslt_ext::Gemm gemm(handle, transpose_a ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                            transpose_b ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                            HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ,
@@ -173,7 +202,7 @@ torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b,
   TORCH_CUDABLAS_CHECK(
       hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, tmpAlgo));
 
-  CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, nullptr));
+  CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, workspace));
   CHECK_HIPBLASLT_ERROR(gemm.run(stream));
 
   // hipFree(d_scaleA);
@@ -260,7 +289,7 @@ torch::Tensor fp8_gemm_16(torch::Tensor& a, torch::Tensor& b,
   auto stream = at::cuda::getCurrentCUDAStream();
 
   hipblaslt_ext::GemmPreference gemmPref;
-  gemmPref.setMaxWorkspaceBytes(0);
+  gemmPref.setMaxWorkspaceBytes(workspace_size);
   hipblaslt_ext::Gemm gemm(handle, transpose_a ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                            transpose_b ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                            HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_16F,
@@ -314,7 +343,7 @@ torch::Tensor fp8_gemm_16(torch::Tensor& a, torch::Tensor& b,
   TORCH_CUDABLAS_CHECK(
       hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, tmpAlgo));
 
-  CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, nullptr));
+  CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, workspace));
   CHECK_HIPBLASLT_ERROR(gemm.run(stream));
 
   // hipFree(d_scaleA);
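
Note that get_hipblaslt_workspace_size() above treats the environment value as KiB and multiplies by 1024 before the hipMalloc. A small plain-Python mirror of that conversion, for illustration only (the helper name is made up and nothing below calls the real extension):

def hipblaslt_workspace_bytes(env_value):
    """Mirror of get_hipblaslt_workspace_size(): env value in KiB, default 32 MiB."""
    size_kib = 32 * 1024  # same fallback as the C++ code
    if env_value is not None:
        try:
            size_kib = int(env_value)
        except ValueError:
            pass  # the C++ code warns and keeps the default
    return size_kib * 1024


assert hipblaslt_workspace_bytes(None) == 32 * 1024 * 1024            # 32 MiB default
assert hipblaslt_workspace_bytes(str(76 * 1024)) == 76 * 1024 * 1024  # 76 MiB
assert hipblaslt_workspace_bytes("oops") == 32 * 1024 * 1024          # falls back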

vllm/model_executor/layers/quantization/fp8_rocm.py

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@ class Fp8RocmConfig(QuantizationConfig):
     def __init__(self) -> None:
         self._tuned = {}
         gemm_type = os.getenv("FP8_GEMM", "fp8_16")
+        vllm_ops.create_workspace()
         if gemm_type == "fp8_8":
             self.gemm_method = Fp8RocmLinearMethod.apply_fp8_8
             tuned_filename = "/tmp/tuned_fp8_8.csv"
