black, isort
Andrei Panferov committed Jan 16, 2024
1 parent 89d3faa commit f498eaf
Showing 4 changed files with 164 additions and 95 deletions.
40 changes: 19 additions & 21 deletions convert_to_hf.py
@@ -1,12 +1,10 @@
import re
import os
import json
import os
import re

import torch

from transformers import AutoConfig, PretrainedConfig

from tqdm.auto import trange
from transformers import AutoConfig, PretrainedConfig


def get_num_layers(config) -> int:
@@ -15,17 +13,17 @@ def get_num_layers(config) -> int:
return config["num_hidden_layers"]
case unknown_type:
raise NotImplementedError(f"Can't get number of layers for {unknown_type}")


def get_layers_prefix(config) -> str:
match config["model_type"]:
case "llama":
return "model.layers"
case unknown_type:
raise NotImplementedError(f"Can't get layers prefix for {unknown_type}")
def pack_ints(data: torch.IntTensor, nbits: int) -> torch.IntTensor:


def pack_ints(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
match nbits:
case x if x <= 8:
return data.to(torch.uint8)
@@ -44,7 +42,7 @@ def pack_ints(data: torch.IntTensor, nbits: int) -> torch.IntTensor:

def get_converted_state_dict(config, nbits: int, in_path: os.PathLike) -> dict:
state_dict = {}

num_layers = get_num_layers(config)
layers_prefix = get_layers_prefix(config)

@@ -58,19 +56,19 @@ def get_converted_state_dict(config, nbits: int, in_path: os.PathLike) -> dict:

for key, value in torch.load(os.path.join(in_path, "not_quantized_weights.pt")).items():
state_dict[key] = value

return state_dict


def get_metadata(in_path: os.PathLike) -> dict:
quant_args = torch.load(os.path.join(in_path, "args.pt"))
return {
'nbits_per_codebook': quant_args['nbits_per_codebook'],
'num_codebooks': quant_args['num_codebooks'],
'out_group_size': quant_args['out_group_size'],
'in_group_size': quant_args['in_group_size'],
"nbits_per_codebook": quant_args["nbits_per_codebook"],
"num_codebooks": quant_args["num_codebooks"],
"out_group_size": quant_args["out_group_size"],
"in_group_size": quant_args["in_group_size"],
}


if __name__ == "__main__":
import argparse
@@ -93,12 +91,12 @@ def get_metadata(in_path: os.PathLike) -> dict:
help="Path to save HF compatible checkpoint to",
)
args = parser.parse_args()

config, _ = PretrainedConfig.get_config_dict(args.model)
metadata = get_metadata(args.in_path)
config["aqlm"] = metadata
with open(os.path.join(args.out_path, "config.json"), "w") as config_file:
json.dump(config, config_file)
state_dict = get_converted_state_dict(config, metadata['nbits_per_codebook'], args.in_path)
torch.save(state_dict, os.path.join(args.out_path, "pytorch_model.bin"))
json.dump(config, config_file)

state_dict = get_converted_state_dict(config, metadata["nbits_per_codebook"], args.in_path)
torch.save(state_dict, os.path.join(args.out_path, "pytorch_model.bin"))
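
Note: the script above writes the quantization hyperparameters read from args.pt into the model's config.json under an "aqlm" key. A minimal sketch of that entry, with purely hypothetical values (not taken from this commit):

import json

config = {"model_type": "llama"}  # abbreviated HF config dict
config["aqlm"] = {
    "nbits_per_codebook": 16,  # each codebook holds 2**16 code vectors
    "num_codebooks": 1,
    "out_group_size": 1,
    "in_group_size": 8,
}
print(json.dumps(config, indent=2))
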
30 changes: 19 additions & 11 deletions modeling_llama.py
@@ -26,9 +26,13 @@
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers import LlamaConfig
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import (
@@ -38,11 +42,9 @@
logging,
replace_return_docstrings,
)
from transformers import LlamaConfig

from src.aq import QuantizedLinear


if is_flash_attn_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
@@ -235,9 +237,7 @@ def forward(self, x):
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)

gate_proj = torch.cat(
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
)
gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)

intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
@@ -282,10 +282,18 @@ def __init__(self, config: LlamaConfig):
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = QuantizedLinear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias, **config.aqlm)
self.k_proj = QuantizedLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, **config.aqlm)
self.v_proj = QuantizedLinear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, **config.aqlm)
self.o_proj = QuantizedLinear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias, **config.aqlm)
self.q_proj = QuantizedLinear(
self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias, **config.aqlm
)
self.k_proj = QuantizedLinear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, **config.aqlm
)
self.v_proj = QuantizedLinear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, **config.aqlm
)
self.o_proj = QuantizedLinear(
self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias, **config.aqlm
)
self._init_rope()

def _init_rope(self):
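
Note: the edit above replaces every nn.Linear projection in LlamaAttention with QuantizedLinear, forwarding the quantization settings via **config.aqlm. A minimal sketch of how that dict reaches the config object; the values are hypothetical, and extra keyword arguments are simply stored as attributes by transformers' PretrainedConfig:

from transformers import LlamaConfig

config = LlamaConfig(
    hidden_size=4096,
    num_attention_heads=32,
    aqlm={"nbits_per_codebook": 16, "num_codebooks": 1, "out_group_size": 1, "in_group_size": 8},
)
# Inside LlamaAttention this dict is unpacked as QuantizedLinear(..., **config.aqlm)
print(config.aqlm)
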
62 changes: 46 additions & 16 deletions src/aq.py
@@ -8,19 +8,30 @@
from tqdm.auto import trange

from src.kmeans import find_nearest_cluster, fit_faiss_kmeans, fit_kmeans, fit_kmeans_1d
from src.utils import ellipsis, maybe_script, get_int_dtype
from src.matmul_kernels import aqlm_gemv_simple
from src.utils import ellipsis, get_int_dtype, maybe_script


class QuantizedLinear(nn.Module):
EPS = 1e-9

def __init__(self, in_features: int, out_features: int , in_group_size: int, out_group_size: int, num_codebooks: int, nbits_per_codebook: int, bias=True, device=None, dtype=None):
factory_kwargs = {'device': device, 'dtype': dtype}

def __init__(
self,
in_features: int,
out_features: int,
in_group_size: int,
out_group_size: int,
num_codebooks: int,
nbits_per_codebook: int,
bias=True,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features

assert self.in_features % in_group_size == 0
assert self.out_features % out_group_size == 0
num_out_groups = out_features // out_group_size
@@ -29,23 +40,36 @@ def __init__(self, in_features: int, out_features: int , in_group_size: int, out
self.num_codebooks = num_codebooks
self.nbits_per_codebook = nbits_per_codebook
self.codebook_size = 2**nbits_per_codebook

self.codebooks = nn.Parameter(
torch.empty((num_codebooks, self.codebook_size, out_group_size, in_group_size), **factory_kwargs), requires_grad=True
torch.empty((num_codebooks, self.codebook_size, out_group_size, in_group_size), **factory_kwargs),
requires_grad=True,
) # [num_codebooks, codebook_size, out_group_size, in_group_size]
self.codes = nn.Parameter(torch.empty((num_out_groups, num_in_groups, num_codebooks), device=device, dtype=get_int_dtype(nbits_per_codebook)), requires_grad=False) # [num_out_groups, num_in_groups, num_codebooks]
self.scales = nn.Parameter(torch.empty((num_out_groups, 1, 1, 1), **factory_kwargs), requires_grad=True) # [num_out_groups, 1, 1, 1]

self.codes = nn.Parameter(
torch.empty(
(num_out_groups, num_in_groups, num_codebooks), device=device, dtype=get_int_dtype(nbits_per_codebook)
),
requires_grad=False,
) # [num_out_groups, num_in_groups, num_codebooks]
self.scales = nn.Parameter(
torch.empty((num_out_groups, 1, 1, 1), **factory_kwargs), requires_grad=True
) # [num_out_groups, 1, 1, 1]

if bias:
self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
else:
self.register_parameter('bias', None)
self.register_parameter("bias", None)

def forward(self, input: torch.Tensor) -> torch.Tensor:
# return F.linear(input, self.reconstruct_weight(), self.bias)
original_shape = input.shape
input = input.reshape(-1, original_shape[-1])
return torch.cat([aqlm_gemv_simple(input_vector.unsqueeze(0), self.codes, self.codebooks, self.scales) for input_vector in input]).reshape(original_shape[:-1] + (-1,))
return torch.cat(
[
aqlm_gemv_simple(input_vector.unsqueeze(0), self.codes, self.codebooks, self.scales)
for input_vector in input
]
).reshape(original_shape[:-1] + (-1,))

def initialize(
self,
@@ -56,7 +80,10 @@ def initialize(
assert reference_weight.shape == (self.out_features, self.in_features)
with torch.no_grad():
weight_groupwise = reference_weight.reshape(
self.out_features // self.out_group_size, self.out_group_size, self.in_features // self.in_group_size, self.in_group_size
self.out_features // self.out_group_size,
self.out_group_size,
self.in_features // self.in_group_size,
self.in_group_size,
).swapaxes(
1, 2
) # [num_out_groups, num_in_groups, out_group_size, in_group_size]
Expand All @@ -77,7 +104,6 @@ def initialize(
self.codes.data = codes
self.codebooks.data = codebooks


def get_codebooks(self) -> torch.Tensor:
"""Get quantization codebooks or reconstruct them from second level quantization (see codebook_values_nbits)"""
return self.codebooks
@@ -95,7 +121,11 @@ def reconstruct_weight(self, selection: Union[slice, ellipsis, torch.Tensor] = .
Formally, the indices must be in range [ 0 , self.out_features // self.out_group_size )
"""
weight = _dequantize_weight(self.codes[selection].to(torch.int64) % (2 ** self.nbits_per_codebook), self.get_codebooks(), self.get_scales()[selection])
weight = _dequantize_weight(
self.codes[selection].to(torch.int64) % (2**self.nbits_per_codebook),
self.get_codebooks(),
self.get_scales()[selection],
)
return weight

@torch.no_grad()
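
Note: a minimal usage sketch for the QuantizedLinear module defined in src/aq.py above, assuming the repository's src package is importable. The layer sizes are hypothetical and the parameters stay uninitialized (torch.empty), so only the shapes are meaningful:

import torch
from src.aq import QuantizedLinear

layer = QuantizedLinear(
    in_features=4096, out_features=11008,
    in_group_size=8, out_group_size=1,
    num_codebooks=1, nbits_per_codebook=16,
    bias=False,
)
print(layer.codebooks.shape)  # (1, 65536, 1, 8)  [num_codebooks, codebook_size, out_group_size, in_group_size]
print(layer.codes.shape)      # (11008, 512, 1)   [num_out_groups, num_in_groups, num_codebooks]
print(layer.scales.shape)     # (11008, 1, 1, 1)
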
