Commit 3437e3a

change the method of detecting linear layers (#849)
1 parent 70233bc commit 3437e3a
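
The change swaps isinstance checks for exact type membership checks (type(m) in supported_types), so only layers whose class is exactly one of the supported types (nn.Linear, nn.Conv2d, transformers' Conv1D) are detected; subclasses and wrapped variants are now skipped instead of being treated as plain linear layers. A minimal sketch of the behavioral difference, using a hypothetical MyLinear subclass and an illustrative SUPPORTED_LAYER_TYPES tuple (neither taken from the repo):

import torch.nn as nn

# Illustrative tuple; the real one in auto_round also covers transformers.pytorch_utils.Conv1D.
SUPPORTED_LAYER_TYPES = (nn.Linear, nn.Conv2d)

class MyLinear(nn.Linear):
    """Hypothetical subclass standing in for a custom or already-wrapped layer."""
    pass

layer = MyLinear(16, 32)

# Old detection: isinstance matches subclasses too, so MyLinear is picked up as a plain Linear.
print(isinstance(layer, SUPPORTED_LAYER_TYPES))  # True

# New detection: exact type membership skips the subclass.
print(type(layer) in SUPPORTED_LAYER_TYPES)      # False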

20 files changed, +73 -75 lines


auto_round/compressors/base.py

Lines changed: 4 additions & 4 deletions
@@ -766,7 +766,7 @@ def _check_compatibility(self) -> None:
            and any(key in fmt for fmt in self.formats for key in ("auto_round", "auto_gptq", "auto_awq"))
        ):
            for n, m in self.model.named_modules():
-               if isinstance(m, self.supported_types):
+               if type(m) in self.supported_types:
                    if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0:
                        self.layer_config[n] = {"bits": 16}
                        logger.info(
@@ -1991,7 +1991,7 @@ def _set_layerwise_config(self, layer_config: dict) -> bool:
        is_gguf = hasattr(self, "formats") and any("gguf" in format_ for format_ in self.formats)
        for n, m in self.model.named_modules():
            # Skip unsupported types
-           if not isinstance(m, supported_types) and m.__class__.__name__ not in self.inner_supported_types:
+           if type(m) not in supported_types and m.__class__.__name__ not in self.inner_supported_types:
                if n in self.layer_config:
                    if not isinstance(m, torch.nn.Embedding):
                        logger.warning(f"{n} is not supported, layer_config {n}: {layer_config[n]} will be ignored.")
@@ -2495,7 +2495,7 @@ def _replace_forward(self):
        from functools import partial

        for n, m in self.model.named_modules():
-           if n in self.to_cached_layers and not isinstance(m, tuple(self.supported_types)): ##block
+           if n in self.to_cached_layers and type(m) not in self.supported_types: ##block
                m.orig_forward = m.forward
                m.forward = partial(self._get_block_forward_func(n), m)
            elif n in self.to_cached_layers: ##linear layer or conv1d layer
@@ -3219,7 +3219,7 @@ def _get_quantized_layer_names_outside_blocks(self) -> list:
            if layer is None:
                logger.error(f"could not find layer {key} in the model, exit...")
                exit(-1)
-           if isinstance(layer, tuple(self.supported_types)) and check_to_quantized(self.layer_config[key]):
+           if type(layer) in self.supported_types and check_to_quantized(self.layer_config[key]):
                layer_names.append(key)

        return layer_names

auto_round/export/export_to_autogptq/export.py

Lines changed: 4 additions & 4 deletions
@@ -73,7 +73,7 @@ def pack_layer(name, model, backend, device=None):
        return
    layer = get_module(model, name)

-   if not isinstance(layer, SUPPORTED_LAYER_TYPES): # already packed
+   if type(layer) not in SUPPORTED_LAYER_TYPES: # already packed
        return

    orig_device = layer.weight.device # must place after 74
@@ -86,13 +86,13 @@ def pack_layer(name, model, backend, device=None):

    QuantLinear = get_autogptq_packing_qlinear(backend, bits, group_size, sym)

-   if isinstance(layer, nn.Linear):
+   if type(layer) == nn.Linear:
        in_features = layer.in_features
        out_features = layer.out_features
-   elif isinstance(layer, nn.Conv2d):
+   elif type(layer) == nn.Conv2d:
        in_features = layer.in_channels
        out_features = layer.out_channels
-   elif isinstance(layer, transformers.pytorch_utils.Conv1D):
+   elif type(layer) == transformers.pytorch_utils.Conv1D:
        in_features = layer.weight.shape[0]
        out_features = layer.weight.shape[1]

auto_round/export/export_to_autogptq/qlinear_triton.py

Lines changed: 2 additions & 2 deletions
@@ -85,9 +85,9 @@ def pack(self, linear, scales, zeros, g_idx=None, device=None):
        self.scales = scales_t.clone().half()

        W = linear.weight.data.to(device).clone()
-       if isinstance(linear, nn.Conv2d):
+       if type(linear) == nn.Conv2d:
            W = W.flatten(1)
-       if isinstance(linear, transformers.pytorch_utils.Conv1D):
+       if type(linear) == transformers.pytorch_utils.Conv1D:
            W = W.t()

        repeat_scales = scales.to(device).repeat_interleave(self.group_size, 1)

auto_round/export/export_to_autoround/export.py

Lines changed: 7 additions & 7 deletions
@@ -118,13 +118,13 @@ def pack_qact_layer(name, model):

    QuantLinear = auto_round.export.export_to_autoround.qlinear_triton_act.QuantLinear

-   if isinstance(layer, nn.Linear):
+   if type(layer) == nn.Linear:
        in_features = layer.in_features
        out_features = layer.out_features
-   elif isinstance(layer, nn.Conv2d):
+   elif type(layer) == nn.Conv2d:
        in_features = layer.in_channels
        out_features = layer.out_channels
-   elif isinstance(layer, transformers.pytorch_utils.Conv1D):
+   elif type(layer) == transformers.pytorch_utils.Conv1D:
        in_features = layer.weight.shape[0]
        out_features = layer.weight.shape[1]
    bias = layer.bias is not None
@@ -181,7 +181,7 @@ def pack_layer(layer_name, model, backend, device=None):
    if hasattr(layer, "orig_layer"):
        layer = layer.orig_layer

-   if not isinstance(layer, SUPPORTED_LAYER_TYPES): ##already packed
+   if type(layer) not in SUPPORTED_LAYER_TYPES: ##already packed
        return

    if int(layer.act_bits) <= 8:
@@ -200,13 +200,13 @@ def pack_layer(layer_name, model, backend, device=None):
        zp = layer.zp
    QuantLinear = dynamic_import_quant_linear_for_packing(backend, bits, group_size, sym, act_bits)

-   if isinstance(layer, nn.Linear):
+   if type(layer) == nn.Linear:
        in_features = layer.in_features
        out_features = layer.out_features
-   elif isinstance(layer, nn.Conv2d):
+   elif type(layer) == nn.Conv2d:
        in_features = layer.in_channels
        out_features = layer.out_channels
-   elif isinstance(layer, transformers.pytorch_utils.Conv1D):
+   elif type(layer) == transformers.pytorch_utils.Conv1D:
        in_features = layer.weight.shape[0]
        out_features = layer.weight.shape[1]
    bias = layer.bias is not None

auto_round/export/export_to_autoround/export_to_fp8.py

Lines changed: 3 additions & 3 deletions
@@ -92,7 +92,7 @@ def pack_layer(layer_name, model, data_type, device=None):
    if hasattr(layer, "orig_layer"):
        layer = layer.orig_layer

-   if not isinstance(layer, SUPPORTED_LAYER_TYPES): ##already packed
+   if type(layer) not in SUPPORTED_LAYER_TYPES: ##already packed
        return

    if not check_to_quantized(layer):
@@ -119,13 +119,13 @@ def pack_layer(layer_name, model, data_type, device=None):
    q_weight = revert_tensor_by_pad(q_weight, orig_shape=orig_shape, pad_len=pad_len)
    q_weight = torch.clamp(q_weight, info.min, info.max)
    q_weight = q_weight.to(torch_dtype)
-   if isinstance(layer, torch.nn.Linear):
+   if type(layer) == torch.nn.Linear:
        in_features = layer.in_features
        out_features = layer.out_features
    # elif isinstance(layer, nn.Conv2d):
    #     in_features = layer.in_channels
    #     out_features = layer.out_channels
-   elif isinstance(layer, transformers.pytorch_utils.Conv1D):
+   elif type(layer) == transformers.pytorch_utils.Conv1D:
        in_features = layer.weight.shape[0]
        out_features = layer.weight.shape[1]
    bias = layer.bias

auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py

Lines changed: 5 additions & 5 deletions
@@ -54,7 +54,7 @@ def pack_layer(name, model, backend, device=None):
    if name == "lm_head": # TODO: Check vLLM inference status to determine whether to enable this feature
        return
    layer = get_module(model, name)
-   if not isinstance(layer, SUPPORTED_LAYER_TYPES) and not isinstance(layer, WrapperWALayer): ##already packed
+   if type(layer) not in SUPPORTED_LAYER_TYPES and not isinstance(layer, WrapperWALayer): ##already packed
        return

    if isinstance(layer, WrapperWALayer): # revert WrapperWALayer for offline usage
@@ -83,13 +83,13 @@ def pack_layer(name, model, backend, device=None):

    # QuantLinear = get_fp_qlinear(backend, bits, group_size, sym)

-   if isinstance(layer, nn.Linear):
+   if type(layer) == nn.Linear:
        in_features = layer.in_features
        out_features = layer.out_features
-   elif isinstance(layer, nn.Conv2d):
+   elif type(layer) == nn.Conv2d:
        in_features = layer.in_channels
        out_features = layer.out_channels
-   elif isinstance(layer, transformers.pytorch_utils.Conv1D):
+   elif type(layer) == transformers.pytorch_utils.Conv1D:
        in_features = layer.weight.shape[0]
        out_features = layer.weight.shape[1]

@@ -172,7 +172,7 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs):
    if is_nv_fp(act_data_type) and "static_gs" in str(act_data_type).lower():
        # generate static input_global_scale
        for n, m in model.named_modules():
-           if isinstance(m, SUPPORTED_LAYER_TYPES):
+           if type(m) in SUPPORTED_LAYER_TYPES:
                layer = m
                if layer.act_bits < 8 and not getattr(layer, "input_global_scale", None):
                    assert hasattr(layer, "act_max")

auto_round/export/export_to_autoround/qlinear_fp.py

Lines changed: 2 additions & 2 deletions
@@ -136,9 +136,9 @@ def pack(self, linear, scales, zeros=None, g_idx=None, global_scale=None, input_
            self.bias = linear.bias.detach().to(torch.float16)

        W = linear.weight.data.detach().to(device)
-       if isinstance(linear, nn.Conv2d):
+       if type(linear) == nn.Conv2d:
            W = W.flatten(1)
-       if isinstance(linear, transformers.pytorch_utils.Conv1D):
+       if type(linear) == transformers.pytorch_utils.Conv1D:
            W = W.t()

        tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(W, self.group_size)

auto_round/export/export_to_autoround/qlinear_triton_act.py

Lines changed: 2 additions & 2 deletions
@@ -129,9 +129,9 @@ def pack(self, linear, scales, zeros, act_scales, w_bf16_to_fp8_scale, g_idx=Non
        self.scales = scales_t.clone().half()

        W = linear.weight.data.to(device).clone()
-       if isinstance(linear, nn.Conv2d):
+       if type(linear) == nn.Conv2d:
            W = W.flatten(1)
-       if isinstance(linear, transformers.pytorch_utils.Conv1D):
+       if type(linear) == transformers.pytorch_utils.Conv1D:
            W = W.t()

        repeat_scales = scales.to(device).repeat_interleave(self.group_size, 1)

auto_round/export/export_to_awq/export.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def pack_layer(name, model, backend, device=None):
        return
    layer = get_module(model, name)

-   if not isinstance(layer, SUPPORTED_LAYER_TYPES): ##already packed
+   if type(layer) not in SUPPORTED_LAYER_TYPES: ##already packed
        return

    bits = layer.bits

auto_round/export/export_to_itrex/export.py

Lines changed: 3 additions & 3 deletions
@@ -227,13 +227,13 @@ def pack_model(
            else:
                scale = scale.to(dtype=convert_dtype)
            zp = zp.to(dtype=torch.int32) if isinstance(zp, torch.Tensor) else zp
-           if isinstance(m, transformers.pytorch_utils.Conv1D):
+           if type(m) == transformers.pytorch_utils.Conv1D:
                fp_weight = fp_weight.t_().contiguous()
            int_weight = quant_weight_w_scale(fp_weight, scale, zp, group_size, fp_weight.device)
-           if isinstance(m, torch.nn.Linear):
+           if type(m) == torch.nn.Linear:
                in_features = m.in_features
                out_features = m.out_features
-           elif isinstance(m, transformers.pytorch_utils.Conv1D):
+           elif type(m) == transformers.pytorch_utils.Conv1D:
                in_features = m.weight.shape[0]
                out_features = m.weight.shape[1]
            int_weight = int_weight.type(torch.int32)
