-
-
Notifications
You must be signed in to change notification settings - Fork 8.4k
[Kernel] (2/N) Machete - Integrate into CompressedTensorsWNA16 and GPTQMarlin #7701
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
mgoin
merged 27 commits into
vllm-project:main
from
neuralmagic:lwilkinson/machete-end2end
Sep 23, 2024
Merged
Changes from all commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
458d69e
squash-patch changes
LucasWilkinson 1ee3608
remove gptq support
LucasWilkinson ab7507e
formatting + fixes
LucasWilkinson 68ff26d
add gptq_marlin support back
LucasWilkinson 7b9e8b2
remove extra prints
LucasWilkinson 30f1056
add machete act ordering
LucasWilkinson 3bbb902
update heuristic
LucasWilkinson 196a9f2
add to tests
LucasWilkinson 38f5b84
update benchmark
LucasWilkinson c59449b
tweak for llama 405b
LucasWilkinson 3048911
env var for disabling kernels
LucasWilkinson df7c4c0
format + mypy
LucasWilkinson 6f3f707
yapf format
LucasWilkinson 90b8e03
refactor
LucasWilkinson c264c7a
add g_idx back
LucasWilkinson 2d25a9a
clean-up
LucasWilkinson 62508c5
review comments
LucasWilkinson 84cfdb2
fix codespell
LucasWilkinson c452a86
TorchDynamo Compatibility
LucasWilkinson 096dd4a
add permute cols opcheck
LucasWilkinson a98f691
fix correctness test
LucasWilkinson 7c02bcf
bug in filtering kernels by compute capability
LucasWilkinson 95a85c9
Merge remote-tracking branch 'origin/main' into lwilkinson/machete-en…
LucasWilkinson a019473
add requirements.txt
LucasWilkinson 306b283
Merge branch 'main' into lwilkinson/machete-end2end
mgoin e32bfc5
[dbrx] refactor dbrx experts to extend FusedMoe class (#8518)
divakar-amd 05752e9
[Kernel][Bugfix] Delete some more useless code in marlin_moe_ops.cu (…
tlrmchlsmth File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
pandas |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda_fp16.h>

#include <limits>
|
||
// Threads launched per CUDA block.
static constexpr int default_threads = 256;

// Integer ceiling division: smallest q such that q * b >= a (positive b).
static constexpr int div_ceil(int a, int b) {
  const int biased = a + b - 1;
  return biased / b;
}
|
||
// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
// Currently only supports 16bit types (since we permute half types)
//
// Launch shape: one thread block per slab of `block_rows` rows; all
// `blockDim.x` threads of a block cooperate on one row at a time.
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  // Clamp this block's row range so the final block does not run past M.
  int start_row = block_rows * blockIdx.x;
  int finish_row = start_row + block_rows;
  if (finish_row > size_m) {
    finish_row = size_m;
  }
  // Explicit clamp instead of std::max: the original relied on <algorithm>
  // being transitively included, and host std:: functions are fragile in
  // device code. Result is identical.
  int cur_block_rows = finish_row > start_row ? finish_row - start_row : 0;

  // Row stride in int4 units: size_k halves * 2 bytes / 16 bytes per int4.
  int row_stride = size_k * sizeof(half) / 16;

  // Copies one row, applying the column permutation: out[k] = a[perm[k]].
  auto permute_row = [&](int row) {
    int iters = size_k / default_threads;  // full warps-worth passes
    int rest = size_k % default_threads;   // tail columns

    int offset = row * row_stride;

    half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
    half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);

    int base_k = 0;

    // Full passes: every thread moves one column per pass.
    for (int i = 0; i < iters; i++) {
      int cur_k = base_k + threadIdx.x;
      int src_pos = perm_int_ptr[cur_k];

      out_half[cur_k] = a_row_half[src_pos];

      base_k += default_threads;
    }

    // Tail pass: only the first `rest` threads have a column to move.
    if (rest) {
      if (threadIdx.x < rest) {
        int cur_k = base_k + threadIdx.x;
        int src_pos = perm_int_ptr[cur_k];

        out_half[cur_k] = a_row_half[src_pos];
      }
    }
  };

  for (int i = 0; i < cur_block_rows; i++) {
    int cur_row = start_row + i;
    if (cur_row < size_m) {
      permute_row(cur_row);
    }
  }
}
|
||
// More efficient version of A[..., perm]
// taken from gptq_marlin.cu
//
// A:    contiguous 16-bit tensor of shape [..., K], K % 8 == 0.
// perm: contiguous int32 tensor with exactly K entries (column indices).
// Returns a new tensor D with D[..., k] = A[..., perm[k]].
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
  auto dev = A.get_device();
  auto stream = at::cuda::getCurrentCUDAStream(dev);

  TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16,
              "Currently only 16bit types are supported");
  TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
  TORCH_CHECK(A.size(-1) % 8 == 0,
              "A columns must be a multiple of 8 (128bits)");
  // The kernel reads perm as raw int32, one entry per column of A; a
  // mismatched perm would read (and index) out of bounds.
  TORCH_CHECK(perm.scalar_type() == at::kInt, "perm must be int32");
  TORCH_CHECK(perm.is_contiguous(), "perm must be contiguous");
  TORCH_CHECK(perm.numel() == A.size(-1),
              "perm must have one entry per column of A");
  auto A_2d = A.view({-1, A.size(-1)});

  torch::Tensor D = torch::empty_like(A);
  int sms;
  cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  // The kernel takes `int` sizes; guard against silent int64 -> int
  // truncation for very large inputs.
  TORCH_CHECK(A_2d.size(0) <= std::numeric_limits<int>::max() &&
                  A_2d.size(1) <= std::numeric_limits<int>::max(),
              "A dimensions exceed int range");
  int block_rows = div_ceil(A_2d.size(0), sms);
  permute_cols_kernel<<<sms, default_threads, 0, stream>>>(
      reinterpret_cast<int4 const*>(A_2d.const_data_ptr()),
      perm.const_data_ptr<int>(), reinterpret_cast<int4*>(D.mutable_data_ptr()),
      A_2d.size(0), A_2d.size(1), block_rows);
  return D;
}
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.