[Core] Set linear_weights directly on the layer #3977

Merged
merged 9 commits on Apr 11, 2024
Changes from 1 commit
Review comment
Yard1 committed Apr 11, 2024
commit e20cdc1671830e0bdd295cc24de12868867e1c1e
11 changes: 7 additions & 4 deletions vllm/model_executor/layers/quantization/gptq.py
@@ -80,12 +80,15 @@ class ExllamaState(Enum):
 class GPTQLinearMethod(LinearMethodBase):
     """Linear method for GPTQ.

+    Note this linear method holds its own state.
+
     Args:
         quant_config: The GPTQ quantization config.
     """

     def __init__(self, quant_config: GPTQConfig):
         self.quant_config = quant_config
+        self.exllama_state = ExllamaState.UNINITIALIZED

     def create_weights(
         self,
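Note on the hunk above: the commit makes GPTQLinearMethod keep the ExLlama bookkeeping on the method object itself (self.exllama_state, initialized to UNINITIALIZED in __init__) instead of on the layer module. A minimal sketch of that contrast, using made-up names (ToyMethodWithOwnState and ToyMethodWithLayerState are illustrative, not vLLM classes):

import torch
from torch import nn


class ToyMethodWithOwnState:
    """Illustrative only: mirrors this commit, state lives on the method object."""

    def __init__(self):
        self.ready = False  # counterpart of self.exllama_state above

    def apply_weights(self, layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
        if not self.ready:
            self.ready = True  # one-time setup on the first forward pass
        return x @ layer.weight.t()


class ToyMethodWithLayerState:
    """Illustrative only: the pre-commit shape, state tracked on the layer module."""

    def apply_weights(self, layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
        if not getattr(layer, "ready", False):
            layer.ready = True  # one-time setup, tracked per layer
        return x @ layer.weight.t()


layer = nn.Linear(8, 16, bias=False)
method = ToyMethodWithOwnState()
y = method.apply_weights(layer, torch.randn(4, 8))  # one-time setup runs here

One general consequence of self-held state (plain Python semantics, not a claim about vLLM's wiring): if a single method instance were ever shared by several layers, they would all share the flag, so the sketch assumes one method instance per layer.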
@@ -191,7 +194,7 @@ def create_weights(
         layer.register_parameter("scales", scales)
         set_weight_attrs(scales, extra_weight_attrs)

-        layer.exllama_state = exllama_state
+        self.exllama_state = exllama_state

     def apply_weights(self,
                       layer: torch.nn.Module,
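For context on the PR title: as the context lines above show, create_weights registers the quantized tensors as parameters directly on the layer module (layer.register_parameter plus set_weight_attrs) rather than routing them through a separate linear_weights container. A minimal sketch of that registration pattern (toy names; toy_set_weight_attrs attaching loader metadata as attributes on the parameter is an assumption about what set_weight_attrs does):

import torch
from torch import nn


def toy_set_weight_attrs(param: nn.Parameter, attrs: dict) -> None:
    # Assumption: set_weight_attrs attaches loader metadata (e.g. input_dim,
    # output_dim, a weight_loader callback) as attributes on the parameter.
    for key, value in attrs.items():
        setattr(param, key, value)


def toy_create_weights(layer: nn.Module, num_groups: int, out_features: int,
                       extra_weight_attrs: dict) -> None:
    # Register the quantization tensor directly on the layer module,
    # mirroring layer.register_parameter("scales", scales) in the diff.
    scales = nn.Parameter(torch.empty(num_groups, out_features),
                          requires_grad=False)
    layer.register_parameter("scales", scales)
    toy_set_weight_attrs(scales, extra_weight_attrs)


layer = nn.Module()
toy_create_weights(layer, num_groups=4, out_features=32,
                   extra_weight_attrs={"output_dim": 1})
print(layer.scales.shape)  # torch.Size([4, 32]); the tensor now lives on the layer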
@@ -202,18 +205,18 @@ def apply_weights(self,
         reshaped_x = x.reshape(-1, x.shape[-1])
         # exllama needs to shuffle the weight after the weight is loaded
         # here we do the shuffle on first forward pass
-        if layer.exllama_state == ExllamaState.UNINITIALIZED:
+        if self.exllama_state == ExllamaState.UNINITIALIZED:
             if self.quant_config.desc_act:
                 layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
             else:
                 layer.g_idx.data = torch.empty((0, ),
                                                device=layer.g_idx.device)
-            layer.exllama_state = ExllamaState.READY
+            self.exllama_state = ExllamaState.READY
             ops.gptq_shuffle(layer.qweight, layer.g_idx,
                              self.quant_config.weight_bits)
         output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
                                layer.scales, layer.g_idx,
-                               layer.exllama_state == ExllamaState.READY,
+                               self.exllama_state == ExllamaState.READY,
                                self.quant_config.weight_bits)
         if bias is not None:
             output.add_(bias)
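The hunk above is a lazy, one-time setup: the first call to apply_weights reorders the packed GPTQ weight for the ExLlama kernel via ops.gptq_shuffle and flips the state to READY, and every later call goes straight to ops.gptq_gemm. A rough, self-contained sketch of that call pattern (toy names; plain matmuls stand in for the vLLM custom ops):

import torch


class LazyOnce:
    """Illustrative only: run an expensive prep step on the first forward call."""

    def __init__(self):
        self.prepared = False

    def forward(self, weight: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        if not self.prepared:
            # Stand-in for ops.gptq_shuffle(...): reorder/prepare the weight
            # in place once, after it has been fully loaded.
            weight.data = weight.data.contiguous()
            self.prepared = True
        # Stand-in for ops.gptq_gemm(...): the steady-state matmul path.
        return x @ weight.t()


m = LazyOnce()
w = torch.randn(16, 8)
y1 = m.forward(w, torch.randn(4, 8))   # prep runs on this first call
y2 = m.forward(w, torch.randn(4, 8))   # prep is skipped on later calls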