Transformers save/load compatibility and inference kernels #3
Merged
111 commits
e159c7f
Added 1x16 CUDA kernel
efrantar 9ec6b70
Conversion script
395a9f6
black, src, indent
8e80100
pack_int_data
4cabbc3
deprecated double quant, scales
2c97a9f
estimate_nbits_per_parameter
7864596
scales fix
3f66363
First triton kernel
07986d9
black, isort
56359db
Quantization refactoring started
6ad6a6b
restored double quant
46be31d
estimate_nbits_per_parameter
6ee4353
less diff aq_engine
9e917f1
bias processing
e281962
removed debug prints
8c35e89
additional kwargs in config
a795ee6
removed matmul kernel
10c0a7e
packing and unpacking integers
c41f09e
packs and unpacks
ed1430f
undoing
eb89781
FinalizedQuantizedLinear
96301f7
tied up and saving
c19a5f6
fixed saving
0013d8e
removed unsupported kwargs
89d6085
triton kernel again
444f788
bias in triton
0f5483a
renamed smt
a9e65bb
new configuration ides
27e3855
inference file copying
0f37ea6
separated saving
e3b623f
Fixed cloning
fe9a9f2
skernel
164c5d1
better saving
db0d5a2
isort
0d7e1af
removed unnecessary dependencies
9f83593
lm_eval tokenizer trust remote code
990d5b2
llama tokenizer
b0b59f5
Deleted llama tokenizers
fd8395f
faster triton kernel
e519193
has_bias tl constexpr
9e89f61
cpp_kernel benchmarks
e673e0b
better order
0d69cf6
better compile flags
59e1f3c
removed unnecessary pragmas
e98ed7c
fixed stuff
c245c5e
removed test function
BlackSamorez 147ed79
icpx
BlackSamorez a5ae331
inference_lib
BlackSamorez 11e5dfd
inference lib done
e5e9ee9
Correct modeling_llama.py
48d6ddf
new version and fixed path
51d664b
undoing src and main
8979ae1
Merge remote-tracking branch 'origin/cuda-kernel' into transformers_cuda
4f31eae
cuda kernel
c9cf936
cuda kernel integration
d0f6ed4
removed cpp kernel
6cc2756
removed src changes
7476833
rmd testing notebook
eb8c2cd
dev3
1db5115
include nonpython files
7f7e853
benchmarks (temp)
5b3a5d2
test update
1804499
Some fixes and added 2x8 kernel
efrantar 07f72b6
Merge remote-tracking branch 'origin/cuda-kernel' into transformers
bf0880f
new kernels
d7c4561
kernel asserts fix
823db17
numba kernel
22a7994
cleaner benchmark
b906bfd
handling flash-attn
6a6ebd3
no cuda import
BlackSamorez 3937640
numba kernel working
BlackSamorez c643fec
black isort
d67d119
newer matmul benchmark
BlackSamorez c31d532
Merge branch 'transformers' of github.com:Vahe1994/AQLM into transfor…
BlackSamorez 2d0cae8
fixed transposes
3deeab2
updated benchmarks
BlackSamorez aca05dd
removed extra benchmarks
cfa5e4a
less diff
9498bf3
benchmarks
7c6d234
Merge branch 'transformers' of github.com:Vahe1994/AQLM into transfor…
426a7b6
numba parallel and style
78cc9a8
cuda moved
1278164
moved cuda kernel
2dbd188
moved numba kernel
935347e
removed unnecessary functions
7b8faf8
dev7
33b0464
updated manifest
ead1c00
dev9
88d9a93
Update transformers/llama/modeling_llama_aqlm.py
BlackSamorez 28d70f8
Update benchmark/generate_benchmark.py
BlackSamorez b31a3fc
Update benchmark/generate_benchmark.py
BlackSamorez d9f6b25
Update inference_lib/setup.cfg
BlackSamorez 503ff40
correct authors
26ff8b0
cpp 1x16
09a7810
2x8 matmat cpp
c434d42
dev10
788c289
colab example
9fdf0a6
black
f2ef38b
colab example notebook
7342655
dev11 fix from Elias
989d5d8
dev12 __CUDA_ARCH__
5d4f4f3
much stuff
2a32c0a
readme, demo, req
f019b4e
more readme
d90c43b
dtype asserts
e06a789
black
098363a
installation
d7b6dfa
1.0.0
4bd67b9
1.0.0 for colab
d44c29d
deleted output
79706d0
mistral and mixtral
@@ -0,0 +1,110 @@
import argparse
import os

os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
import time
import warnings

warnings.filterwarnings("ignore")

import torch

torch.set_num_threads(8)
from torch import nn

from transformers import AutoConfig, AutoModelForCausalLM

if __name__ == "__main__":
    parser = argparse.ArgumentParser(add_help=True)
    parser.add_argument(
        "--model",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--num_codebooks",
        type=int,
        default=None,
    )
    parser.add_argument(
        "--in_group_size",
        type=int,
        default=None,
    )
    parser.add_argument(
        "--nbits_per_codebook",
        type=int,
        default=None,
    )
    parser.add_argument(
        "--warmup_iters",
        type=int,
        default=1,
        help="Number of warmup iterations.",
    )
    parser.add_argument(
        "--benchmark_iters",
        type=int,
        default=3,
        help="Number of benchmark iterations.",
    )
    parser.add_argument(
        "--input_length",
        type=int,
        default=1,
        help="Input length.",
    )
    parser.add_argument(
        "--output_length",
        type=int,
        default=128,
        help="Output length.",
    )
    args = parser.parse_args()

    device = "cpu"

    config = AutoConfig.from_pretrained(args.model, trust_remote_code=True, torch_dtype=torch.float32)
    if args.num_codebooks is not None:
        config.aqlm["num_codebooks"] = args.num_codebooks
    if args.in_group_size is not None:
        config.aqlm["in_group_size"] = args.in_group_size
    if args.nbits_per_codebook is not None:
        config.aqlm["nbits_per_codebook"] = args.nbits_per_codebook

    real_num_layers = config.num_hidden_layers
    if "meta-llama" in args.model:
        config.num_hidden_layers = 1
    aqlm_model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, torch_dtype=torch.float32)

    if "meta-llama" in args.model:
        aqlm_model.config.num_hidden_layers = real_num_layers
        layer = aqlm_model.model.layers[0]
        aqlm_model.model.layers = nn.ModuleList([])
        for i in range(real_num_layers):
            another_layer = type(layer)(config, i)

            another_layer.self_attn.q_proj.weight.data = layer.self_attn.q_proj.weight.data
            another_layer.self_attn.k_proj.weight.data = layer.self_attn.k_proj.weight.data
            another_layer.self_attn.v_proj.weight.data = layer.self_attn.v_proj.weight.data
            another_layer.self_attn.o_proj.weight.data = layer.self_attn.o_proj.weight.data
            another_layer.mlp.up_proj.weight.data = layer.mlp.up_proj.weight.data
            another_layer.mlp.down_proj.weight.data = layer.mlp.down_proj.weight.data
            another_layer.mlp.gate_proj.weight.data = layer.mlp.gate_proj.weight.data

            another_layer.self_attn.layer_idx = i
            aqlm_model.model.layers.append(another_layer)

        aqlm_model.model.config.num_hidden_layers = real_num_layers

    prompt = torch.randint(low=0, high=aqlm_model.config.vocab_size, size=(1, args.input_length), device=device)

    for i in range(args.warmup_iters + args.benchmark_iters):
        aqlm_model.generate(prompt, min_new_tokens=args.output_length, max_new_tokens=args.output_length)
        if i == args.warmup_iters - 1:
            t_s = time.perf_counter()
    t_e = time.perf_counter()

    tokens_per_second = args.benchmark_iters * args.output_length / (t_e - t_s)
    print(f"<Tokens per second> = {tokens_per_second:.3f}")
@@ -0,0 +1,107 @@
import argparse
import os
import time
import warnings

warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
from tqdm import trange

from transformers import AutoConfig, AutoModelForCausalLM

if __name__ == "__main__":
    assert torch.cuda.is_available()
    device = torch.device("cuda")

    parser = argparse.ArgumentParser(add_help=True)

    parser.add_argument(
        "--model",
        type=str,
        required=True,
    )
    parser.add_argument(
        "--warmup_iters",
        type=int,
        default=1,
        help="Number of warmup iterations.",
    )
    parser.add_argument(
        "--benchmark_iters",
        type=int,
        default=10,
        help="Number of benchmark iterations.",
    )
    parser.add_argument(
        "--input_length",
        type=int,
        default=1,
        help="Input length.",
    )
    parser.add_argument(
        "--output_length",
        type=int,
        default=128,
        help="Output length.",
    )
    parser.add_argument(
        "--real_model",
        action="store_true",
    )
    parser.add_argument(
        "--low_cpu_mem_usage",
        action="store_true",
    )

    args = parser.parse_args()


def load_model(model_name, device="cuda"):
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype="auto",
    ).to(device)


def load_shared_model(model_name, device="cuda"):
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    num_layers = config.num_hidden_layers
    config.num_hidden_layers = 1
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, torch_dtype=torch.float16).to(device)
    layer = model.model.layers[0]
    for i in trange(1, num_layers, desc="Copying block parameters"):
        new_layer = type(layer)(model.config, i).to(device)
        for new_layer_param, layer_param in zip(new_layer.parameters(), layer.parameters()):
            new_layer_param.data = layer_param.data
        new_layer.self_attn.layer_idx = i
        model.model.layers.append(new_layer)
    return model


if __name__ == "__main__":
    assert torch.cuda.is_available()
    device = torch.device("cuda")

    parser = argparse.ArgumentParser(add_help=True)

    config = AutoConfig.from_pretrained(args.model, trust_remote_code=True)

    if args.real_model:
        aqlm_model = load_model(args.model, device)
    else:
        aqlm_model = load_shared_model(args.model, device)

    prompt = torch.randint(low=0, high=aqlm_model.config.vocab_size, size=(1, args.input_length), device=device)

    for i in range(args.warmup_iters + args.benchmark_iters):
        output = aqlm_model.generate(prompt, min_new_tokens=args.output_length, max_new_tokens=args.output_length)
        if i == args.warmup_iters - 1:
            torch.cuda.synchronize(device)
            t_s = time.perf_counter()
    torch.cuda.synchronize(device)
    t_e = time.perf_counter()

    tokens_per_second = args.benchmark_iters * args.output_length / (t_e - t_s)
    print(f"<Tokens per second> = {tokens_per_second:.2f}")
Review comment: I can't find an else statement here. Consider adding one, e.g. else: raise NotImplementedError(...)?
Reply: It's more complicated: the check is meant to differentiate between quantized and unquantized models that way.
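To make the reply concrete, below is a minimal, self-contained sketch (not the PR's actual code) of the pattern being described: whether a linear layer becomes a quantized AQLM layer or stays a plain nn.Linear is decided by whether the model config carries an aqlm section, which is why no else/NotImplementedError branch is needed. MyQuantizedLinear is only a stand-in for the PR's FinalizedQuantizedLinear, and make_linear is a hypothetical helper.

import torch.nn as nn


class MyQuantizedLinear(nn.Module):
    # Stand-in for an AQLM quantized linear layer; the real kernels live in the PR's inference lib.
    def __init__(self, in_features, out_features, bias=True, **aqlm_kwargs):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.aqlm_kwargs = aqlm_kwargs  # e.g. num_codebooks, in_group_size, nbits_per_codebook


def make_linear(config, in_features, out_features, bias=True):
    # Quantized checkpoints carry an `aqlm` dict in their config (see the benchmark
    # scripts above); plain, unquantized checkpoints do not.
    aqlm_params = getattr(config, "aqlm", None)
    if aqlm_params is not None:
        return MyQuantizedLinear(in_features, out_features, bias=bias, **aqlm_params)
    # Unquantized model: keep a regular dense layer instead of raising NotImplementedError.
    return nn.Linear(in_features, out_features, bias=bias)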