[Kernel] add bfloat16 support for gptq marlin kernel #4788
Conversation
@alexm-nm can you review this?
Thanks for doing the work of adding bfloat16 to marlin. Left some comments.
@@ -9,6 +9,10 @@
#include <cuda_runtime.h>
#include <iostream>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
Why is it necessary to check here that SM >= 8.0? Shouldn't the #include <cuda_bf16.h> work regardless?
@@ -38,6 +42,7 @@ constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
// No support for async
#else
nit: formatting
C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
}
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
This is a problematic way to add bfloat16 support to marlin, since we should be able to compile the marlin module for both float16 and bfloat16 at the same time. Could you restructure the code to instead pass a template parameter to the Marlin<...> kernel, and use that template parameter for all of the functions that need a templated type? If you don't have time, then I can take over and fix it for you. Tell me what works.
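For illustration only, here is a minimal toy sketch of the kind of restructuring being requested — the ScalarTraits helper, scale_kernel, and its arguments are invented for this example (loosely modeled on the ScalarType<> helper that appears later in the diff) and are not the actual Marlin code:

```cpp
// Toy example: compile one kernel template for both half and nv_bfloat16,
// instead of selecting the element type for the whole file with an #ifdef.
#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Minimal per-type trait (illustrative only).
template <typename scalar_t>
struct ScalarTraits;

template <>
struct ScalarTraits<half> {
  static __device__ float to_float(half x) { return __half2float(x); }
  static __device__ half from_float(float x) { return __float2half(x); }
};

template <>
struct ScalarTraits<nv_bfloat16> {
  static __device__ float to_float(nv_bfloat16 x) { return __bfloat162float(x); }
  static __device__ nv_bfloat16 from_float(float x) { return __float2bfloat16(x); }
};

// One kernel template; the element type is a parameter, not a compile-time #ifdef.
template <typename scalar_t>
__global__ void scale_kernel(const scalar_t* in, scalar_t* out, float s, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    float v = ScalarTraits<scalar_t>::to_float(in[i]) * s;
    out[i] = ScalarTraits<scalar_t>::from_float(v);
  }
}

// The host side then launches whichever instantiation matches the tensor dtype, e.g.
//   scale_kernel<half><<<blocks, threads>>>(...);
//   scale_kernel<nv_bfloat16><<<blocks, threads>>>(...);
```

With this shape, both the float16 and bfloat16 versions exist in the same compiled module and the dispatch happens at runtime on the tensor dtype.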
Ok, I will restructure it soon.
@alexm-nm I have restructured the code. Can you review it again?
@jinzhen-lin this looks much better with the template param! I left some minor comments. Could you also add a test to test_gptq_marlin.py with some models that run with dtype.bfloat16 (so we have correctness verified on every change going forward). Again, thanks for the help!
size_k, workspace.data_ptr(), num_bits, has_act_order, is_k_full,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, gptq_marlin::max_par);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
You had an #ifdef above to check for CUDA_ARCH >= 8.0 where you access nv_bfloat16. I suppose it generates a compilation error if you don't have the ifdef. I think you should have an ifdef here as well to disable the bfloat16 case, so the code compiles for SM < 8.0.
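One possible shape of that guard, as a sketch only: the function name, arguments, and the MARLIN_ENABLE_BF16 macro below are assumptions, not the actual vLLM code. Note also that __CUDA_ARCH__ is only defined during device compilation, so a host-side dispatch branch like this one would need a flag supplied by the build system instead.

```cpp
// Sketch of guarding the bfloat16 dispatch branch so the extension still builds
// when the bfloat16 kernels are not compiled in (all names here are illustrative).
#include <torch/extension.h>

void gemm_dispatch(const torch::Tensor& a /* , ... remaining marlin arguments ... */) {
  if (a.scalar_type() == at::ScalarType::Half) {
    // ... launch the float16 instantiation ...
  } else if (a.scalar_type() == at::ScalarType::BFloat16) {
#ifdef MARLIN_ENABLE_BF16  // hypothetical flag, defined only when targeting SM >= 8.0
    // ... launch the bfloat16 instantiation ...
#else
    TORCH_CHECK(false, "bfloat16 gptq_marlin requires a build targeting SM >= 8.0");
#endif
  } else {
    TORCH_CHECK(false, "gptq_marlin_gemm only supports float16 and bfloat16");
  }
}
```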
};

template <>
class ScalarType<half> {
This looks much better! Thanks for doing this.
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);

fp32_intermediates[0] -= 8388736.f;
What code is this dequant_8bit based on? Maybe you can document the reference you used.
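As an aside, here is a small host-side C++ illustration of what the magic-constant trick in this hunk computes (it says nothing about where the code originated). The value of fp32_base is not visible in the excerpt, so the bit pattern 0x4B000000, i.e. the float 2^23, is assumed here, and plain C++ stands in for the PTX byte_perm form:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // Placing an 8-bit quantized value v into the low mantissa byte of the float
  // 2^23 (bit pattern 0x4B000000) produces exactly the float 8388608 + v, which
  // is what __byte_perm(q, fp32_base, 0x765x) assembles in registers.
  uint8_t v = 200;
  uint32_t bits = 0x4B000000u | v;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  // Subtracting 8388736.f (= 2^23 + 128) then yields the shifted value v - 128.
  std::printf("%g\n", f - 8388736.f);  // prints 72, i.e. 200 - 128
  return 0;
}
```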
@bnellnm could you do a quick pass on the template changes?
__device__ inline FragB dequant_4bit(int q) {
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {
  throw std::runtime_error("unsupported");
I'm not sure what the standard is, but I think most checks in the code use TORCH_CHECK rather than throw.
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) | ||
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), | ||
"r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); | ||
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) { |
I think it would be safer to make the else clause a static_assert, so that if a new type were added, this function would not silently compile with an empty body, i.e.

} else {
  static_assert(std::is_same<scalar_t, half>::value);
  asm volatile(...);
}
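For reference, a self-contained toy version of that suggestion (the to_float helper below is invented for illustration, not taken from the PR): because the static_assert condition depends on scalar_t, it only fires when the else branch is actually instantiated, so any unsupported type fails at compile time instead of compiling to an empty body.

```cpp
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <type_traits>

// Toy device helper showing the static_assert-in-else pattern with if constexpr.
template <typename scalar_t>
__device__ float to_float(scalar_t x) {
  if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
    return __bfloat162float(x);
  } else {
    // Only half remains valid here; any other scalar_t now fails to compile
    // rather than silently producing an empty or incorrect body.
    static_assert(std::is_same<scalar_t, half>::value, "unsupported scalar_t");
    return __half2float(x);
  }
}
```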

template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
  throw std::runtime_error("unsupported");
TORCH_CHECK?
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, gptq_marlin::max_par);
} else {
  throw std::runtime_error("gpt_marlin_gemm only supports bfloat16 and float16");
TORCH_CHECK here too?
The template changes look good. I had a few minor comments, mostly about using TORCH_CHECK over throw (which I think is more "standard").
@jinzhen-lin I think your code is in a good state to land after addressing the last comments.
@alexm-nm @bnellnm All previous comments have been fixed. As for the test in test_gptq_marlin.py, it has been added as well.
@jinzhen-lin thanks for adding the tests and fixing all comments. @robertgshaw2-neuralmagic looks good to me to proceed forward.
Thanks all!
Some models overflow when using fp16 inference (e.g. Deepseek-V2), so we should add bfloat16 support for the quantization kernels. This PR adds bfloat16 support for the gptq marlin kernel.

Unlike the gptq kernel in #4781, the gptq marlin kernel doesn't use atomicAdd, so the performance of bfloat16 is close to that of float16.

Related issue: #2149

Main changes: