@@ -9,8 +9,6 @@
 #include <hipblaslt/hipblaslt.h>
 #include <hipblaslt/hipblaslt-ext.hpp>

-#define max_workspace_size 2 * 128 * 1024 * 1024
-
 #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) \
   TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
@@ -36,6 +34,37 @@
 }
 #endif

+static void* workspace = nullptr;
+static size_t workspace_size;
+
+// Copied from
+// https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/cuda/tunable/GemmHipblaslt.h
+static size_t get_hipblaslt_workspace_size() {
+  static const char* env = getenv("HIPBLASLT_WORKSPACE_SIZE");
+  // 256MB is max workspace size allowed for hipblaslt
+  // hipblaslt-bench uses 32MB
+  // recommendation from hipblaslt author was 76MB
+  size_t workspace_size = 32 * 1024;  // going with 32MB
+  if (env) {
+    try {
+      workspace_size = std::stoi(env);
+    } catch (std::invalid_argument const& e) {
+      TORCH_WARN("invalid HIPBLASLT_WORKSPACE_SIZE,",
+                 " using default workspace size of ", workspace_size, " KiB.");
+    } catch (std::out_of_range const& e) {
+      TORCH_WARN("HIPBLASLT_WORKSPACE_SIZE out of range,",
+                 " using default workspace size of ", workspace_size, " KiB.");
+    }
+  }
+  return workspace_size * 1024;
+}
+
+void create_workspace() {
+  workspace_size = get_hipblaslt_workspace_size();
+  if (workspace_size > 0)
+    CHECK_HIP_ERROR(hipMalloc(&workspace, workspace_size));
+}
+
 torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b,
                        torch::Tensor& scaleA, torch::Tensor& scaleB,
                        torch::Tensor& scaleD, int algo_idx) {
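Note on units: the helper above reads HIPBLASLT_WORKSPACE_SIZE in KiB and only converts to bytes on return, so the 32 * 1024 default means 32 MiB. A minimal standalone sketch of that conversion (illustration only; workspace_bytes_from_env and the 77824 value are hypothetical and not part of this diff):

// Standalone sketch of the KiB-to-bytes convention used by
// get_hipblaslt_workspace_size(); names and values here are illustrative.
#include <cstdio>
#include <cstdlib>
#include <string>

static size_t workspace_bytes_from_env() {
  const char* env = std::getenv("HIPBLASLT_WORKSPACE_SIZE");
  size_t kib = 32 * 1024;   // default: 32 MiB, expressed in KiB
  if (env) {
    try {
      kib = std::stoul(env);  // the environment variable is interpreted as KiB
    } catch (const std::exception&) {
      // keep the default on a malformed or out-of-range value
    }
  }
  return kib * 1024;          // KiB -> bytes
}

int main() {
  setenv("HIPBLASLT_WORKSPACE_SIZE", "77824", 1);  // 77824 KiB == 76 MiB
  std::printf("workspace bytes: %zu\n", workspace_bytes_from_env());
  return 0;
}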
@@ -116,7 +145,7 @@ torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b,
   auto stream = at::cuda::getCurrentCUDAStream();

   hipblaslt_ext::GemmPreference gemmPref;
-  gemmPref.setMaxWorkspaceBytes(0);
+  gemmPref.setMaxWorkspaceBytes(workspace_size);
   hipblaslt_ext::Gemm gemm(handle, transpose_a ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                            transpose_b ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                            HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ,
@@ -173,7 +202,7 @@ torch::Tensor fp8_gemm(torch::Tensor& a, torch::Tensor& b,
   TORCH_CUDABLAS_CHECK(
       hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, tmpAlgo));

-  CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, nullptr));
+  CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, workspace));
   CHECK_HIPBLASLT_ERROR(gemm.run(stream));

   // hipFree(d_scaleA);
@@ -260,7 +289,7 @@ torch::Tensor fp8_gemm_16(torch::Tensor& a, torch::Tensor& b,
   auto stream = at::cuda::getCurrentCUDAStream();

   hipblaslt_ext::GemmPreference gemmPref;
-  gemmPref.setMaxWorkspaceBytes(0);
+  gemmPref.setMaxWorkspaceBytes(workspace_size);
   hipblaslt_ext::Gemm gemm(handle, transpose_a ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                            transpose_b ? HIPBLAS_OP_T : HIPBLAS_OP_N,
                            HIP_R_8F_E4M3_FNUZ, HIP_R_8F_E4M3_FNUZ, HIP_R_16F,
@@ -314,7 +343,7 @@ torch::Tensor fp8_gemm_16(torch::Tensor& a, torch::Tensor& b,
   TORCH_CUDABLAS_CHECK(
       hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, tmpAlgo));

-  CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, nullptr));
+  CHECK_HIPBLASLT_ERROR(gemm.initialize(tmpAlgo[0].algo, workspace));
   CHECK_HIPBLASLT_ERROR(gemm.run(stream));

   // hipFree(d_scaleA);
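The hunks shown here only define create_workspace(); its call site is not visible in this excerpt. One plausible wiring, assuming the usual torch-extension registration block elsewhere in this file (the m.def docstrings below are placeholders), would be:

// Hypothetical wiring (assumption, not shown in this diff): allocate the
// shared hipBLASLt workspace once when the extension module is loaded.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  create_workspace();  // hipMalloc the workspace before any GEMM runs
  m.def("fp8_gemm", &fp8_gemm, "FP8 GEMM (FP8 output)");
  m.def("fp8_gemm_16", &fp8_gemm_16, "FP8 GEMM (16-bit output)");
}

Calling create_workspace() lazily from the first GEMM invocation would work as well; the important part is that the allocation happens before gemm.initialize() receives the pointer, since gemmPref now advertises workspace_size bytes to the heuristic.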