Transformers save/load compatibility and inference kernels #3

Merged
merged 111 commits on Feb 7, 2024

Changes from 1 commit

Commits (111)
e159c7f
Added 1x16 CUDA kernel
efrantar Jan 27, 2024
9ec6b70
Conversion script
Jan 16, 2024
395a9f6
black, src, indent
Jan 16, 2024
8e80100
pack_int_data
Jan 16, 2024
4cabbc3
deprecated double quant, scales
Jan 14, 2024
2c97a9f
estimate_nbits_per_parameter
Jan 14, 2024
7864596
scales fix
Jan 14, 2024
3f66363
First triton kernel
Jan 16, 2024
07986d9
black, isort
Jan 16, 2024
56359db
Quantization refactoring started
Jan 16, 2024
6ad6a6b
restored double quant
Jan 16, 2024
46be31d
estimate_nbits_per_parameter
Jan 16, 2024
6ee4353
less diff aq_engine
Jan 16, 2024
9e917f1
bias processing
Jan 16, 2024
e281962
removed debug prints
Jan 16, 2024
8c35e89
additional kwargs in config
Jan 16, 2024
a795ee6
removed matmul kernel
Jan 16, 2024
10c0a7e
packing and unpacking integers
Jan 16, 2024
c41f09e
packs and unpacks
Jan 16, 2024
ed1430f
undoing
Jan 17, 2024
eb89781
FinalizedQuantizedLinear
Jan 17, 2024
96301f7
tied up and saving
Jan 17, 2024
c19a5f6
fixed saving
Jan 17, 2024
0013d8e
removed unsupported kwargs
Jan 17, 2024
89d6085
triton kernel again
Jan 18, 2024
444f788
bias in triton
Jan 18, 2024
0f5483a
renamed smt
Jan 18, 2024
a9e65bb
new configuration ides
Jan 18, 2024
27e3855
inference file copying
Jan 18, 2024
0f37ea6
separated saving
Jan 18, 2024
e3b623f
Fixed cloning
Jan 20, 2024
fe9a9f2
skernel
Jan 20, 2024
164c5d1
better saving
Jan 20, 2024
db0d5a2
isort
Jan 20, 2024
0d7e1af
removed unnecessary dependencies
Jan 20, 2024
9f83593
lm_eval tokenizer trust remote code
Jan 21, 2024
990d5b2
llama tokenizer
Jan 21, 2024
b0b59f5
Deleted llama tokenizers
Jan 22, 2024
fd8395f
faster triton kernel
Jan 22, 2024
e519193
has_bias tl constexpr
Jan 22, 2024
9e89f61
cpp_kernel benchmarks
Jan 26, 2024
e673e0b
better order
Jan 26, 2024
0d69cf6
better compile flags
Jan 26, 2024
59e1f3c
removed unnecessary pragmas
Jan 26, 2024
e98ed7c
fixed stuff
Jan 27, 2024
c245c5e
removed test function
BlackSamorez Jan 27, 2024
147ed79
icpx
BlackSamorez Jan 27, 2024
a5ae331
inference_lib
BlackSamorez Jan 28, 2024
11e5dfd
inference lib done
Jan 28, 2024
e5e9ee9
Correct modeling_llama.py
Jan 28, 2024
48d6ddf
new version and fixed path
Jan 28, 2024
51d664b
undoing src and main
Jan 28, 2024
8979ae1
Merge remote-tracking branch 'origin/cuda-kernel' into transformers_cuda
Jan 28, 2024
4f31eae
cuda kernel
Jan 28, 2024
c9cf936
cuda kernel integration
Jan 28, 2024
d0f6ed4
removed cpp kernel
Jan 28, 2024
6cc2756
removed src changes
Jan 28, 2024
7476833
rmd testing notebook
Jan 28, 2024
eb8c2cd
dev3
Jan 28, 2024
1db5115
include nonpython files
Jan 28, 2024
7f7e853
benchmarks (temp)
Jan 29, 2024
5b3a5d2
test update
Jan 29, 2024
1804499
Some fixes and added 2x8 kernel
efrantar Jan 29, 2024
07f72b6
Merge remote-tracking branch 'origin/cuda-kernel' into transformers
Jan 29, 2024
bf0880f
new kernels
Jan 29, 2024
d7c4561
kernel asserts fix
Jan 30, 2024
823db17
numba kernel
Jan 30, 2024
22a7994
cleaner benchmark
Jan 30, 2024
b906bfd
handling flash-attn
Jan 30, 2024
6a6ebd3
no cuda import
BlackSamorez Jan 30, 2024
3937640
numba kernel working
BlackSamorez Jan 30, 2024
c643fec
black isort
Jan 30, 2024
d67d119
newer matmul benchmark
BlackSamorez Jan 31, 2024
c31d532
Merge branch 'transformers' of github.com:Vahe1994/AQLM into transfor…
BlackSamorez Jan 31, 2024
2d0cae8
fixed transposes
Jan 31, 2024
3deeab2
updated benchmarks
BlackSamorez Jan 31, 2024
aca05dd
removed extra benchmarks
Feb 5, 2024
cfa5e4a
less diff
Feb 5, 2024
9498bf3
benchmarks
Feb 5, 2024
7c6d234
Merge branch 'transformers' of github.com:Vahe1994/AQLM into transfor…
Feb 5, 2024
426a7b6
numba parallel and style
Feb 6, 2024
78cc9a8
cuda moved
Feb 6, 2024
1278164
moved cuda kernel
Feb 6, 2024
2dbd188
moved numba kernel
Feb 6, 2024
935347e
removed unnecessary functions
Feb 6, 2024
7b8faf8
dev7
Feb 6, 2024
33b0464
updated manifest
Feb 6, 2024
ead1c00
dev9
Feb 6, 2024
88d9a93
Update transformers/llama/modeling_llama_aqlm.py
BlackSamorez Feb 6, 2024
28d70f8
Update benchmark/generate_benchmark.py
BlackSamorez Feb 6, 2024
b31a3fc
Update benchmark/generate_benchmark.py
BlackSamorez Feb 6, 2024
d9f6b25
Update inference_lib/setup.cfg
BlackSamorez Feb 6, 2024
503ff40
correct authors
Feb 6, 2024
26ff8b0
cpp 1x16
Feb 6, 2024
09a7810
2x8 matmat cpp
Feb 6, 2024
c434d42
dev10
Feb 6, 2024
788c289
colab example
Feb 6, 2024
9fdf0a6
black
Feb 6, 2024
f2ef38b
colab example notebook
Feb 6, 2024
7342655
dev11 fix from Elias
Feb 6, 2024
989d5d8
dev12 __CUDA_ARCH__
Feb 6, 2024
5d4f4f3
much stuff
Feb 7, 2024
2a32c0a
readme, demo, req
Feb 7, 2024
f019b4e
more readme
Feb 7, 2024
d90c43b
dtype asserts
Feb 7, 2024
e06a789
black
Feb 7, 2024
098363a
installation
Feb 7, 2024
d7b6dfa
1.0.0
Feb 7, 2024
4bd67b9
1.0.0 for colab
Feb 7, 2024
d44c29d
deleted output
Feb 7, 2024
79706d0
mistral and mixtral
Feb 7, 2024
new kernels
Andrei Panferov committed Jan 29, 2024
commit bf0880f281a610819c54ae002bca20b9b3ab107a
inference_lib/src/aqlm/cuda/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -1 +1 @@
-from .cuda_kernel import cuda_matmul
+from .cuda_kernel import cuda_gemm_1x16, cuda_gemm_2x8
inference_lib/src/aqlm/cuda/cuda_kernel.py (61 changes: 34 additions & 27 deletions)

@@ -11,22 +11,25 @@
 )


-def cuda_gemm_stupid(
-    input: torch.Tensor,  # [num_inputs, in_features]
+def cuda_gemm_1x16(
+    input: torch.Tensor,  # [..., in_features]
     codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
     codebooks: torch.Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
     scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
     bias: Optional[torch.Tensor],
 ) -> torch.Tensor:
+    input_shape = input.shape
+    input = input.reshape(-1, input_shape[-1])
+
     device, dtype = codebooks.device, codebooks.dtype
     num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
     in_features = input.shape[1]
     out_features = codes.shape[0] * out_group_size
     num_input_groups = codes.shape[1]
     assert input.ndim == 2
     assert scales.shape == (out_features // out_group_size, 1, 1, 1)
     assert in_features % in_group_size == 0
-    assert codebooks.shape[1] == 2**16
+    assert codebook_size == 2**16
     assert num_codebooks == 1
     assert codes.dtype == torch.int16
     assert input.dtype == torch.float16 and codebooks.dtype == torch.float16

@@ -38,33 +41,37 @@ def cuda_gemm_stupid(
     output *= scales.flatten().unsqueeze(0)
     if bias is not None:
         output += bias
-    return output
+    return output.reshape(input_shape[:-1] + (-1,))

-# codebook = torch.randn((codebook_size, in_group_size), dtype=torch.half, device=DEV)
-# A = torch.randint(codebook_size, (out_features, in_features // in_group_size), dtype=torch.int, device=DEV)
-# A_ref = torch.vstack([codebook[A[i]].flatten().unsqueeze(0) for i in range(M)])
-# A = A.to(torch.int16)
-# B = torch.randn((in_features, 1), dtype=torch.half, device=DEV)
-# C = torch.zeros((out_features, 1), dtype=torch.half, device=DEV)
-
-# C_ref = torch.matmul(A_ref, B)
-# codebook_cuda.code16_matvec(A, B, C, codebook)
-

-def cuda_matmul(
-    input: torch.Tensor,
-    codes: torch.IntTensor,
-    codebooks: torch.Tensor,
-    scales: torch.Tensor,
+def cuda_gemm_2x8(
+    input: torch.Tensor,  # [..., in_features]
+    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
+    codebooks: torch.Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
+    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
     bias: Optional[torch.Tensor],
 ) -> torch.Tensor:
     input_shape = input.shape
     input = input.reshape(-1, input_shape[-1])

-    return cuda_gemm_stupid(
-        input,
-        codes,
-        codebooks,
-        scales,
-        bias,
-    ).reshape(input_shape[:-1] + (-1,))
+    device, dtype = codebooks.device, codebooks.dtype
+    num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
+    in_features = input.shape[1]
+    out_features = codes.shape[0] * out_group_size
+    assert input.ndim == 2
+    assert scales.shape == (out_features // out_group_size, 1, 1, 1)
+    assert in_features % in_group_size == 0
+    assert codebook_size == 2**8
+    assert num_codebooks == 1
+    assert codes.dtype == torch.int8
+    assert input.dtype == torch.float16 and codebooks.dtype == torch.float16
+
+    output = torch.zeros(input.shape[0], out_features, device=device, dtype=dtype)
+    for i in range(input.shape[0]):
+        CUDA_KERNEL.code2x8_matvec(
+            codes.squeeze(2), input[i].unsqueeze(-1), output[i].unsqueeze(-1), codebooks.squeeze(0, 2)
+        )
+    output *= scales.flatten().unsqueeze(0)
+    if bias is not None:
+        output += bias
+    return output.reshape(input_shape[:-1] + (-1,))
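For reference, here is a minimal sketch of how the new cuda_gemm_1x16 wrapper could be called, with shapes and dtypes chosen to satisfy the asserts above. The sizes are illustrative, and actually running it assumes the aqlm inference package with its compiled CUDA extension is installed and a CUDA device is available; the snippet itself is not part of this PR.

```python
import torch

from aqlm.cuda.cuda_kernel import cuda_gemm_1x16

# Illustrative sizes for the 1x16 case: one codebook of 2**16 entries,
# out_group_size = 1, in_group_size = 8.
out_features, in_features, in_group_size = 4096, 4096, 8

input = torch.randn(2, 16, in_features, dtype=torch.float16, device="cuda")  # [..., in_features]
codes = torch.randint(
    -(2**15), 2**15, (out_features, in_features // in_group_size, 1), dtype=torch.int16, device="cuda"
)  # [num_out_groups, num_in_groups, num_codebooks]
codebooks = torch.randn(1, 2**16, 1, in_group_size, dtype=torch.float16, device="cuda")
scales = torch.ones(out_features, 1, 1, 1, dtype=torch.float16, device="cuda")

output = cuda_gemm_1x16(input, codes, codebooks, scales, bias=None)
print(output.shape)  # torch.Size([2, 16, 4096]), i.e. [..., out_features]
```

cuda_gemm_2x8 has the same signature but expects 2**8-entry codebooks and int8 codes.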
inference_lib/src/aqlm/inference_kernels/kernel_selector.py (32 changes: 18 additions & 14 deletions)

@@ -16,20 +16,24 @@ def forward_pass_quantized_linear(
     bias: Optional[torch.Tensor],
 ) -> torch.Tensor:
     num_codebooks, codebook_size, out_group_size, in_group_size = codebooks.shape
-    if cuda_kernel_applicable(input.is_cuda, num_codebooks, codebook_size, out_group_size, in_group_size):
-        from aqlm.cuda.cuda_kernel import cuda_matmul
-
-        return cuda_matmul(input, codes, codebooks, scales, bias)
-
-    if triton_kernel_applicable(input.is_cuda):
-        return triton_matmul(input, codes, codebooks, scales, bias)
-
-    dequantized_weight = _dequantize_weight(
-        unpack_int_data(codes, codebooks.shape[0].bit_length() - 1),
-        codebooks,
-        scales,
-    )
-    return F.linear(input, dequantized_weight, bias)
+    match (input.is_cuda, num_codebooks, codebook_size, out_group_size, in_group_size):
+        case (True, 1, 65536, 1, 8):
+            from aqlm.cuda.cuda_kernel import cuda_gemm_1x16
+
+            return cuda_gemm_1x16(input, codes, codebooks, scales, bias)
+        case (True, 2, 256, 1, 8):
+            from aqlm.cuda.cuda_kernel import cuda_gemm_2x8
+
+            return cuda_gemm_2x8(input, codes, codebooks, scales, bias)
+        case (True, _, _, _, _):
+            return triton_matmul(input, codes, codebooks, scales, bias)
+        case _:
+            dequantized_weight = _dequantize_weight(
+                unpack_int_data(codes, codebooks.shape[0].bit_length() - 1),
+                codebooks,
+                scales,
+            )
+            return F.linear(input, dequantized_weight, bias)


 def cuda_kernel_applicable(
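To make the new dispatch rule easier to see at a glance, here is a small standalone sketch that mirrors the match statement above and only reports which backend would be chosen. It requires Python 3.10+ for match; select_kernel is an illustrative helper, not part of the PR.

```python
def select_kernel(
    is_cuda: bool, num_codebooks: int, codebook_size: int, out_group_size: int, in_group_size: int
) -> str:
    """Mirror the kernel-selection match statement and name the backend it would pick."""
    match (is_cuda, num_codebooks, codebook_size, out_group_size, in_group_size):
        case (True, 1, 65536, 1, 8):
            return "cuda_gemm_1x16"  # one 2**16-entry codebook, groups of 8 input features
        case (True, 2, 256, 1, 8):
            return "cuda_gemm_2x8"  # two 2**8-entry codebooks, groups of 8 input features
        case (True, _, _, _, _):
            return "triton_matmul"  # any other configuration on a CUDA device
        case _:
            return "dequantize + F.linear"  # CPU fallback: materialize the weight


assert select_kernel(True, 1, 65536, 1, 8) == "cuda_gemm_1x16"
assert select_kernel(True, 2, 256, 1, 8) == "cuda_gemm_2x8"
assert select_kernel(False, 1, 65536, 1, 8) == "dequantize + F.linear"
```

On CPU, or for unsupported shapes, the real function falls back to dequantizing the weight with _dequantize_weight and calling F.linear.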
main.py (2 changes: 1 addition & 1 deletion)

@@ -8,7 +8,6 @@
 import torch.nn as nn
 from tqdm import trange
 from tqdm.auto import trange
-from transformers import PreTrainedModel

 from aq_engine import AQEngine
 from src.aq import QuantizedLinear
@@ -24,6 +23,7 @@
     get_sequential_groups,
 )
 from src.utils import using_tf32
+from transformers import PreTrainedModel

 try:
     import wandb