Deactivating exllama v2 for sharded models.
It fails with an illegal memory access on CUDA when using sharding.
Took a long while trying to fix it:

- All tensors are correct (same as v1).
- Scratch size doesn't help.
- The error only occurs for sequence lengths > 50 (so during warmup most of the time).

Couldn't figure out why 50 is the particular threshold, nor change anything to fix the behavior.
Narsil committed Nov 25, 2023
1 parent a62c567 commit c766426
Showing 3 changed files with 21 additions and 15 deletions.
14 changes: 13 additions & 1 deletion server/text_generation_server/utils/gptq/exllamav2.py
@@ -51,7 +51,7 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
w["scales"] = w["scales"].half()

# GPTQ with g_idx (act_order)
if "g_idx" in w and not (w["g_idx"] == 0).all().item():
if w.get("g_idx", None) is not None and not (w["g_idx"] == 0).all().item():
w["q_perm"] = torch.empty((w["qweight"].shape[0] * 8,), dtype = torch.short, device = w["qweight"].device)
w["q_invperm"] = torch.empty_like(w["q_perm"])
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
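The one-line change in this hunk swaps a key-presence test for a None test: with the old check, a "g_idx" entry that is present in the dict but holds None passes the "g_idx" in w test and then fails on (None == 0).all(). A minimal sketch contrasting the two checks on a stand-in dict (not part of the diff):

# Stand-in dict; only the keys and the None value matter for the comparison.
w = {"qweight": object(), "g_idx": None}

old_check = "g_idx" in w                      # True, even though the value is None
new_check = w.get("g_idx", None) is not None  # False, so the act-order branch is skipped
print(old_check, new_check)                   # True False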
@@ -113,12 +113,24 @@ def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
self.maxq = 2 ** self.bits - 1
self.infeatures = qweight.shape[0] // self.bits * 32
self.outfeatures = qweight.shape[1]
self.padding = - self.outfeatures % 32
self.outfeatures = self.outfeatures + self.padding

self.device = qweight.device
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.g_idx = g_idx
self.bias = bias if bias is not None else None
self.group_size = groupsize

infeatures = self.infeatures
outfeatures = self.outfeatures
assert qweight.shape == (infeatures // 32 * self.bits, outfeatures)
assert infeatures % self.group_size == 0
assert qzeros.shape == (infeatures // self.group_size, outfeatures // 32 * self.bits)
assert scales.shape == (infeatures // self.group_size, outfeatures)
assert g_idx.shape == (infeatures, ), f"{g_idx.shape}, {infeatures}"

global FIXED_BYTES, LAYERS
FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
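The shape assertions spell out the 4-bit GPTQ packing layout: 32 // bits = 8 weights packed into each int32 along the input dimension, and one zero-point row and one scale row per group. A worked example of that arithmetic with assumed sizes (a 4096x4096 layer with group_size 128, not taken from the commit):

bits, group_size = 4, 128
infeatures, outfeatures = 4096, 4096

qweight_shape = (infeatures // 32 * bits, outfeatures)               # (512, 4096): packed int32 weights
qzeros_shape = (infeatures // group_size, outfeatures // 32 * bits)  # (32, 512): packed zero-points, one row per group
scales_shape = (infeatures // group_size, outfeatures)               # (32, 4096): one fp16 scale per group and output column
g_idx_shape = (infeatures,)                                          # (4096,): group index of every input feature
print(qweight_shape, qzeros_shape, scales_shape, g_idx_shape)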
4 changes: 4 additions & 0 deletions server/text_generation_server/utils/layers.py
@@ -35,6 +35,10 @@
HAS_EXLLAMA = False
CAN_EXLLAMA = major >= 8
V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
if V2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
logger.warning("Disabling exllama v2 and using v1 instead because there are issues when sharding")
V2 = False

if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False
elif CAN_EXLLAMA:
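Condensed into a single helper (an illustrative sketch, not the file's actual structure; the name pick_exllama_version is made up), the environment-variable gating after this change works out to:

import os

def pick_exllama_version(can_exllama: bool):
    # Returns 2, 1, or None depending on hardware support and environment.
    if os.getenv("DISABLE_EXLLAMA") == "True" or not can_exllama:
        return None
    v2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
    if v2 and int(os.getenv("WORLD_SIZE", "1")) > 1:
        # Sharded launch: fall back to v1 until the CUDA illegal-access issue is understood.
        v2 = False
    return 2 if v2 else 1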
18 changes: 4 additions & 14 deletions server/text_generation_server/utils/weights.py
@@ -281,20 +281,10 @@ def get_multi_weights_row(self, prefix: str, quantize: str):
logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")

if use_exllama:
if groupsize >= 0:
# Exllama reorders the weights in advance and the activations on the fly, thus
# the scales and zero-points do not need to be reordered.
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0)
else:
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")

# For tp > 1, at this point we know we do not use act-order
if self.process_group.size() == 1:
g_idx = self.get_tensor(f"{prefix}.g_idx")
else:
g_idx = None
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0)
g_idx = self.get_sharded(f"{prefix}.g_idx", dim= 0)
g_idx = g_idx - g_idx[0]
else:
# The triton kernel reorders the scales/zero points instead of the weight/activation.
# Thus, each rank needs the full qzeros/scales.
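In the exllama branch above, g_idx is now loaded sharded like qzeros and scales, then rebased with g_idx - g_idx[0]. A tiny illustration of the rebase with made-up sizes, assuming the non act-order layout where g_idx increases monotonically:

import torch

group_size, infeatures, world_size, rank = 4, 16, 2, 1
g_idx = torch.arange(infeatures) // group_size   # tensor([0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3])

# Row-parallel sharding hands this rank a contiguous slice of the input features.
start = rank * infeatures // world_size
stop = (rank + 1) * infeatures // world_size
shard = g_idx[start:stop]                        # tensor([2,2,2,2, 3,3,3,3])

# Subtracting the first entry restarts the local group numbering at 0,
# matching the qzeros/scales rows this rank actually holds.
local_g_idx = shard - shard[0]                   # tensor([0,0,0,0, 1,1,1,1])
print(local_g_idx)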
