Commit: Some renaming

danieldk committed Jun 12, 2024
1 parent 9b77526 commit f0187ef
Showing 2 changed files with 20 additions and 20 deletions.

server/text_generation_server/layers/marlin.py (13 additions, 13 deletions)

@@ -59,7 +59,7 @@ def _get_perms() -> Tuple[List[int], List[int]]:
 _scale_perm, _scale_perm_single = _get_perms()
 
 
-def permute_scales(scales: torch.Tensor, in_features: int, group_size: int):
+def permute_scales(scales: torch.Tensor):
     out_features = scales.shape[1]
     if scales.shape[0] == 1:
         scales = scales.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
@@ -95,8 +95,8 @@ def repack_gptq_for_marlin(
     g_idx: torch.Tensor,
     bits: int,
     desc_act: bool,
-    group_size: int,
-    is_sym: bool,
+    groupsize: int,
+    sym: bool,
     sharded_infeatures: bool,
 ) -> GPTQMarlinWeight:
     """Convert GPTQ weights to a layout that's compatible with GPTQ-Marlin kernels."""
@@ -109,12 +109,12 @@ def repack_gptq_for_marlin(
             f"Repacking {bits}-bit GPTQ weights as Marlin is not supported, must be one of: {supported_bits}"
         )
 
-    if group_size not in GPTQ_MARLIN_GROUP_SIZES:
+    if groupsize not in GPTQ_MARLIN_GROUP_SIZES:
         supported_sizes = ", ".join(str(b) for b in GPTQ_MARLIN_GROUP_SIZES)
         raise RuntimeError(
-            f"Repacking GPTQ weights with group size {group_size} as Marlin is not supported, must be one of: {supported_sizes}"
+            f"Repacking GPTQ weights with group size {groupsize} as Marlin is not supported, must be one of: {supported_sizes}"
         )
-    if not is_sym:
+    if not sym:
         raise RuntimeError(
             "Repacking GPTQ weights with asymmetric quantization as Marlin is not supported."
         )
@@ -123,12 +123,12 @@ def repack_gptq_for_marlin(
     in_features = qweight.shape[0] * weights_per_int
     out_features = qweight.shape[1]
 
-    if in_features % group_size != 0:
+    if in_features % groupsize != 0:
         raise ValueError(
-            f"Number of input features ({in_features}) not divisible by group size ({group_size})"
+            f"Number of input features ({in_features}) not divisible by group size ({groupsize})"
         )
 
-    if desc_act and group_size != -1:
+    if desc_act and groupsize != -1:
         perm = torch.argsort(g_idx).to(torch.int)
         g_idx = g_idx[perm]
     else:
@@ -139,7 +139,7 @@ def repack_gptq_for_marlin(
         qweight, perm, in_features, out_features, bits
     )
 
-    scales = permute_scales(scales, in_features, group_size)
+    scales = permute_scales(scales)
 
     is_full_k = not (desc_act and sharded_infeatures)
 
@@ -249,11 +249,11 @@ def __init__(self, *, weight: MarlinWeight, bias: Optional[torch.Tensor]):
             out_features % 256 == 0
         ), f"Number of output features ({out_features}) not divisable by 256"
 
-        group_size = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
-        assert group_size in {
+        groupsize = -1 if weight.s.shape[0] == 1 else in_features // weight.s.shape[0]
+        assert groupsize in {
             -1,
             128,
-        }, f"Group size must be -1 or 128, was {group_size}"
+        }, f"Group size must be -1 or 128, was {groupsize}"
 
         self.register_buffer("B", weight.B)
         self.register_buffer("s", weight.s)
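
The marlin.py change is purely a renaming at the API surface: `repack_gptq_for_marlin` now takes `groupsize` and `sym` instead of `group_size` and `is_sym` (matching the `gptq_params` field names used at the call sites in `weights.py`), and `permute_scales` drops its two unused parameters. A minimal before/after sketch of a call site; the tensor shapes and the `qweight=`/`scales=` keyword names are illustrative assumptions, since the diff only shows the parameter list from `g_idx` onward:

```python
import torch

from text_generation_server.layers.marlin import (
    permute_scales,
    repack_gptq_for_marlin,
)

# Illustrative 4-bit GPTQ tensors: 256 input features packed 8 per int32,
# 512 output features, group size 128 (so 256 / 128 = 2 scale groups).
qweight = torch.zeros((32, 512), dtype=torch.int32)
scales = torch.ones((2, 512), dtype=torch.float16)
g_idx = torch.arange(256, dtype=torch.int32) // 128

# Before this commit the call used group_size=128, is_sym=True.
repacked = repack_gptq_for_marlin(
    qweight=qweight,  # assumed keyword name; not visible in this diff
    scales=scales,    # assumed keyword name; not visible in this diff
    g_idx=g_idx,
    bits=4,
    desc_act=False,
    groupsize=128,  # was: group_size=128
    sym=True,       # was: is_sym=True
    sharded_infeatures=False,
)

# permute_scales no longer takes the dropped arguments:
permuted = permute_scales(scales)  # was: permute_scales(scales, 256, 128)
```

Actually executing the repack requires the GPTQ-Marlin CUDA kernels, so treat this as a call-shape sketch rather than a runnable end-to-end test.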
server/text_generation_server/utils/weights.py (7 additions, 7 deletions)

@@ -251,8 +251,8 @@ def get_weights_col_packed(
                 g_idx=g_idx,
                 bits=gptq_params.bits,
                 desc_act=gptq_params.desc_act,
-                group_size=gptq_params.groupsize,
-                is_sym=gptq_params.sym,
+                groupsize=gptq_params.groupsize,
+                sym=gptq_params.sym,
                 sharded_infeatures=False,
             )
 
@@ -416,8 +416,8 @@ def get_multi_weights_col(self, prefixes: List[str], quantize: str, dim: int):
                 g_idx=g_idx,
                 bits=gptq_params.bits,
                 desc_act=gptq_params.desc_act,
-                group_size=gptq_params.groupsize,
-                is_sym=gptq_params.sym,
+                groupsize=gptq_params.groupsize,
+                sym=gptq_params.sym,
                 sharded_infeatures=False,
             )
         else:
@@ -638,8 +638,8 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
                 g_idx=g_idx,
                 bits=gptq_params.bits,
                 desc_act=gptq_params.desc_act,
-                group_size=gptq_params.groupsize,
-                is_sym=gptq_params.sym,
+                groupsize=gptq_params.groupsize,
+                sym=gptq_params.sym,
                 sharded_infeatures=sharded_in_features,
             )
         else:
@@ -652,7 +652,7 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
 
             num_groups = self._get_slice(f"{prefix}.s").get_shape()[0]
             if num_groups == 1:
-                # The number of groups is 1 when group_size == -1. share
+                # The number of groups is 1 when groupsize == -1. share
                 # scales between all shards in this case.
                 s = self.get_tensor(f"{prefix}.s")
             else:
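
The final hunk documents an invariant worth spelling out: the Marlin scale tensor `s` has one row per quantization group, so `num_groups == 1` is exactly the `groupsize == -1` case, and those scales can be shared by every shard instead of being sliced per shard. A standalone sketch of that sharding rule, where the helper name and slicing arithmetic are hypothetical rather than code from `weights.py`:

```python
import torch


def scales_for_shard(s: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Hypothetical helper showing the rule implied by the diff: one scale
    row per group, and a single group (groupsize == -1) is shared."""
    num_groups = s.shape[0]
    if num_groups == 1:
        # groupsize == -1: one group spans all input features, so every
        # shard loads the same scales.
        return s
    # Otherwise each shard takes its contiguous block of group rows.
    rows_per_shard = num_groups // world_size
    start = rank * rows_per_shard
    return s[start : start + rows_per_shard]


# in_features = 512 with group size 128 gives 4 groups; 2 shards get 2 rows each.
s = torch.ones((4, 1024), dtype=torch.float16)
assert scales_for_shard(s, rank=1, world_size=2).shape == (2, 1024)

# groupsize == -1 gives a single group shared by both shards.
s_shared = torch.ones((1, 1024), dtype=torch.float16)
assert scales_for_shard(s_shared, rank=0, world_size=2).shape == (1, 1024)
```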
