This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Commit

Rs/marlin downstream v0.3.2 (#43)
Co-authored-by: Andrew Feldman <afeldman@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+rib-2@users.noreply.github.com>
Co-authored-by: alexm <alexm@neuralmagic.com>
4 people authored Feb 22, 2024
1 parent acb8615 commit 4b44479
Showing 15 changed files with 1,563 additions and 17 deletions.
9 changes: 9 additions & 0 deletions csrc/ops.h
@@ -80,6 +80,15 @@ torch::Tensor awq_dequantize(
   int split_k_iters,
   int thx,
   int thy);
+
+torch::Tensor marlin_gemm(
+  torch::Tensor &a,
+  torch::Tensor &b_q_weight,
+  torch::Tensor &b_scales,
+  torch::Tensor &workspace,
+  int64_t size_m,
+  int64_t size_n,
+  int64_t size_k);
 #endif

 void squeezellm_gemm(
1 change: 1 addition & 0 deletions csrc/pybind.cpp
@@ -52,6 +52,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 #ifndef USE_ROCM
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
   ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
+  ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
 #endif
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
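Once the extension is built, the marlin_gemm binding registered above is callable from Python. The sketch below is illustrative only: it assumes the extension is importable as vllm._C (as in upstream vLLM around this release) and that the packed weights, scales, and workspace buffer were prepared by the Marlin GPTQ repacking utilities; only the call signature itself comes from this diff.

    # Illustrative sketch; module path and tensor preparation are assumptions,
    # only the marlin_gemm signature is taken from this commit.
    import torch
    from vllm._C import ops  # assumed extension module name

    def marlin_matmul(a: torch.Tensor,
                      b_q_weight: torch.Tensor,
                      b_scales: torch.Tensor,
                      workspace: torch.Tensor) -> torch.Tensor:
        """Thin wrapper over the newly bound kernel.

        a: fp16 activations of shape (size_m, size_k).
        b_q_weight, b_scales: quantized weights and per-group scales in the
        packed layout the Marlin kernel expects (produced outside this sketch).
        workspace: scratch buffer used by the kernel.
        Returns an fp16 tensor of shape (size_m, size_n).
        """
        size_m, size_k = a.shape
        size_n = b_scales.shape[1]  # scales carry the output dimension
        return ops.marlin_gemm(a, b_q_weight, b_scales, workspace,
                               size_m, size_n, size_k)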