[Core] Set linear_weights directly on the layer #3977
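As the diff below shows, `create_weights` no longer returns a dict of tensors that each linear layer stores as `self.linear_weights` and registers itself; it now receives the layer, registers the weights directly on it, and forwards any `**extra_weight_attrs` (such as a `weight_loader`) onto them, while `apply_weights` takes the layer instead of the dict. A rough paraphrase of the call-pattern change, with arguments abbreviated (not a verbatim excerpt of the diff):

```python
# Before: the layer keeps the dict returned by create_weights and registers
# the contained tensors itself.
#
#   self.linear_weights = self.linear_method.create_weights(...)
#   for name, weight in self.linear_weights.items():
#       if isinstance(weight, torch.Tensor):
#           self.register_parameter(name, weight)
#   output = self.linear_method.apply_weights(self.linear_weights, x, bias)
#
# After: create_weights registers parameters on the layer it is given, and
# apply_weights reads them back from that layer.
#
#   self.linear_method.create_weights(self, ...,
#                                     weight_loader=self.weight_loader)
#   output = self.linear_method.apply_weights(self, x, bias)
```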
Changes from all commits
fee1dc4
3b46ae5
dcadfbd
09ef14c
888ba9c
46c2f29
a740d2b
e20cdc1
59b4fb6
File 1: the linear layer module.
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import List, Optional

 import torch
 import torch.nn.functional as F
@@ -29,19 +29,24 @@ class LinearMethodBase(ABC):
     """Base class for different (maybe quantized) linear methods."""

     @abstractmethod
-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
-        """Create weights for a linear layer."""
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+        """Create weights for a linear layer.
+
+        The weights will be set as attributes of the layer."""
         raise NotImplementedError

     @abstractmethod
     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Apply the weights to the input tensor."""
+        """Apply the weights in layer to the input tensor.
+
+        Expects create_weights to have been called before on the layer."""
         raise NotImplementedError


Review comments on the new **extra_weight_attrs parameter:
- "Could this just be an optional […]? I see the value of enabling […]"
- "Actually, making it a kwarg will be easier as we don't need extra handling to not set the […]"
- ":)"
@@ -56,22 +61,24 @@ class UnquantizedLinearMethod(LinearMethodBase):
     def __init__(self, separate_bias_add: bool = False):
         self.separate_bias_add = separate_bias_add

-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
         weight = Parameter(torch.empty(output_size_per_partition,
                                        input_size_per_partition,
                                        dtype=params_dtype),
                            requires_grad=False)
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        return {"weight": weight}
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, extra_weight_attrs)

     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        weight = weights["weight"]
+        weight = layer.weight
         if self.separate_bias_add:
             if bias is not None:
                 return F.linear(x, weight) + bias
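To see what registering the weight on the layer buys, here is a small standalone sketch (a toy, not vLLM code: the bare `torch.nn.Module()` stand-in layer, the local `set_weight_attrs` helper, and the dummy `weight_loader` are assumptions for illustration). The parameter shows up in `named_parameters()`/`state_dict()` without the per-layer registration loop the old code needed, and the `extra_weight_attrs` ride along on it:

```python
import torch
from torch.nn import Parameter


def set_weight_attrs(weight: torch.Tensor, attrs: dict):
    # Stand-in for the helper used above: attach arbitrary attributes
    # (e.g. input_dim, output_dim, weight_loader) to a tensor.
    for key, value in attrs.items():
        setattr(weight, key, value)


def create_weights(layer: torch.nn.Module, output_size: int, input_size: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    weight = Parameter(torch.empty(output_size, input_size,
                                   dtype=params_dtype),
                       requires_grad=False)
    set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
    # The method registers the parameter on the layer itself.
    layer.register_parameter("weight", weight)
    set_weight_attrs(weight, extra_weight_attrs)


layer = torch.nn.Module()  # toy stand-in for a linear layer
create_weights(layer, 8, 4, torch.float16,
               weight_loader=lambda param, loaded: None)

print(list(dict(layer.named_parameters())))    # ['weight']
print(layer.weight.output_dim)                 # 0
print(hasattr(layer.weight, "weight_loader"))  # True
```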
@@ -112,12 +119,9 @@ def __init__(
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size, self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
-                self.register_parameter(name, weight)
+        self.linear_method.create_weights(self, self.input_size,
+                                          self.output_size, self.input_size,
+                                          self.output_size, self.params_dtype)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size, dtype=self.params_dtype))
@@ -127,7 +131,7 @@ def __init__(

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         bias = self.bias if not self.skip_bias_add else None
-        output = self.linear_method.apply_weights(self.linear_weights, x, bias)
+        output = self.linear_method.apply_weights(self, x, bias)
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias

@@ -178,13 +182,13 @@ def __init__(
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size_per_partition, self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
-                self.register_parameter(name, weight)
-                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+        self.linear_method.create_weights(self,
+                                          self.input_size,
+                                          self.output_size_per_partition,
+                                          self.input_size,
+                                          self.output_size,
+                                          self.params_dtype,
+                                          weight_loader=self.weight_loader)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
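The `weight_loader=self.weight_loader` kwarg ends up, through `extra_weight_attrs` and `set_weight_attrs`, as an attribute on each parameter the linear method creates. Model weight loading can then dispatch on that attribute per parameter, roughly as in this sketch (a paraphrase of the usual pattern with hypothetical helper names, not code touched by this PR):

```python
import torch


def default_weight_loader(param: torch.nn.Parameter,
                          loaded_weight: torch.Tensor) -> None:
    # Fallback for parameters that carry no custom loader.
    param.data.copy_(loaded_weight)


def load_checkpoint_weight(params_dict: dict, name: str,
                           loaded_weight: torch.Tensor) -> None:
    param = params_dict[name]
    # Parameters created through create_weights carry a weight_loader
    # attribute (e.g. a tensor-parallel-aware loader from the layer).
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, loaded_weight)
```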
@@ -212,8 +216,7 @@ def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None

         # Matrix multiply.
-        output_parallel = self.linear_method.apply_weights(
-            self.linear_weights, input_, bias)
+        output_parallel = self.linear_method.apply_weights(self, input_, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = tensor_model_parallel_all_gather(output_parallel)
@@ -524,13 +527,13 @@ def __init__(
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size_per_partition, self.output_size, self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
-                self.register_parameter(name, weight)
-                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+        self.linear_method.create_weights(self,
+                                          self.input_size_per_partition,
+                                          self.output_size,
+                                          self.input_size,
+                                          self.output_size,
+                                          self.params_dtype,
+                                          weight_loader=self.weight_loader)

         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
@@ -570,7 +573,7 @@ def forward(self, input_):

         # Matrix multiply.
         output_parallel = self.linear_method.apply_weights(
-            self.linear_weights, input_parallel)
+            self, input_parallel)
         if self.reduce_results and self.tp_size > 1:
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
File 2: the GPTQ quantization method.
@@ -89,12 +89,14 @@ def __init__(self, quant_config: GPTQConfig):

     def create_weights(
         self,
+        layer: torch.nn.Module,
         input_size_per_partition: int,
         output_size_per_partition: int,
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
-    ) -> Dict[str, Any]:
+        **extra_weight_attrs,
+    ):
         del output_size  # Unused.
         if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
@@ -179,37 +181,40 @@ def create_weights(
             "input_dim": scale_and_zero_input_dim,
             "output_dim": 1,
         })
-        return {
-            "qweight": qweight,
-            "g_idx": g_idx,
-            "qzeros": qzeros,
-            "scales": scales,
-            "exllama_state": exllama_state,
-        }
+
+        layer.register_parameter("qweight", qweight)
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("g_idx", g_idx)
+        set_weight_attrs(g_idx, extra_weight_attrs)
+        layer.register_parameter("qzeros", qzeros)
+        set_weight_attrs(qzeros, extra_weight_attrs)
+        layer.register_parameter("scales", scales)
+        set_weight_attrs(scales, extra_weight_attrs)
+
+        layer.exllama_state = exllama_state

     def apply_weights(self,
-                      weights: Dict[str, Any],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        qweight = weights["qweight"]
+        qweight = layer.qweight
         out_shape = x.shape[:-1] + (qweight.shape[-1], )
         reshaped_x = x.reshape(-1, x.shape[-1])
         # exllama needs to shuffle the weight after the weight is loaded
         # here we do the shuffle on first forward pass
-        if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
+        if layer.exllama_state == ExllamaState.UNINITIALIZED:
             if self.quant_config.desc_act:
-                weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
-                    torch.int)
+                layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
             else:
-                weights["g_idx"] = torch.empty((1, 1), device="meta")
-            weights["exllama_state"] = ExllamaState.READY
-            ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
+                layer.g_idx.data = torch.empty((0, ),
+                                               device=layer.g_idx.device)
+            layer.exllama_state = ExllamaState.READY
+            ops.gptq_shuffle(layer.qweight, layer.g_idx,
                              self.quant_config.weight_bits)
-        output = ops.gptq_gemm(reshaped_x, weights["qweight"],
-                               weights["qzeros"], weights["scales"],
-                               weights["g_idx"],
-                               weights["exllama_state"] == ExllamaState.READY,
+        output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
+                               layer.scales, layer.g_idx,
+                               layer.exllama_state == ExllamaState.READY,
                                self.quant_config.weight_bits)
         if bias is not None:
-            output = output + bias
+            output.add_(bias)
         return output.reshape(out_shape)

Review comments on layer.exllama_state:
- "The […] This will avoid having this dangling member […]"
- "ah actually I don't think this is possible since GPTQLinearMethod is a singleton across all layers"
- "Sounds good"
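A side note on the `g_idx` handling in the exllama path (my reading of the diff, not something stated in it): `g_idx` is now a `Parameter` registered on the module, so the old trick of rebinding the dict entry to a fresh tensor cannot be replayed as a plain attribute assignment, because `nn.Module.__setattr__` refuses to replace a registered `Parameter` with an ordinary tensor. Assigning to `.data`, as the new code does, keeps the registered `Parameter` while swapping its contents:

```python
import torch

layer = torch.nn.Module()
layer.register_parameter(
    "g_idx",
    torch.nn.Parameter(torch.zeros(16, dtype=torch.int), requires_grad=False))

# layer.g_idx = torch.empty((0, ), dtype=torch.int)
# -> would raise TypeError: a plain Tensor cannot replace a registered
#    Parameter on an nn.Module.

# Rebinding .data keeps the Parameter registered but replaces its storage,
# which is what the shuffle path above does before calling gptq_shuffle.
layer.g_idx.data = torch.empty((0, ), dtype=torch.int)
print(layer.g_idx.shape)  # torch.Size([0])
```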