Skip to content

Commit 39b02de

Browse files
authored
refactor _replace_linear_8da4w (#451)
* refactor `_replace_linear_8da4w` to use a filter/replacement helper
* clean up version
1 parent dee13e1 commit 39b02de

File tree

1 file changed

+17
-20
lines changed

1 file changed

+17
-20
lines changed

torchao/quantization/GPTQ.py

+17-20
Original file line numberDiff line numberDiff line change
@@ -893,10 +893,15 @@ def _replace_linear_8da4w(
893893
linear_class: Type[torch.nn.Module],
894894
copy_weights: bool = False,
895895
):
896-
for name, child in module.named_children():
897-
if isinstance(child, nn.Linear):
898-
if _check_linear_int4_k(child.in_features, groupsize) or padding_allowed:
899-
new_linear = linear_class(
896+
897+
# Import the util function here to avoid a circular dependency.
898+
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
899+
900+
def filter_fn(child: torch.nn.Module, cur_fqn:str) -> bool:
901+
return isinstance(child, nn.Linear) and (_check_linear_int4_k(child.in_features, groupsize) or padding_allowed)
902+
903+
def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
904+
new_linear = linear_class(
900905
child.in_features,
901906
child.out_features,
902907
bias=False,
@@ -905,22 +910,14 @@ def _replace_linear_8da4w(
905910
precision=precision,
906911
scales_precision=scales_precision,
907912
)
908-
# In distributed training, the model may be instantiated
909-
# on the meta device, in which case there is no need to
910-
# copy the weights, and doing so will result in an error
911-
if copy_weights and child.weight.device != torch.device("meta"):
912-
new_linear.weight = child.weight
913-
setattr(module, name, new_linear)
914-
else:
915-
_replace_linear_8da4w(
916-
child,
917-
groupsize,
918-
padding_allowed,
919-
precision,
920-
scales_precision,
921-
linear_class,
922-
copy_weights,
923-
)
913+
# In distributed training, the model may be instantiated
914+
# on the meta device, in which case there is no need to
915+
# copy the weights, and doing so will result in an error
916+
if copy_weights and child.weight.device != torch.device("meta"):
917+
new_linear.weight = child.weight
918+
return new_linear
919+
920+
_replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
924921

925922
def replace_linear_8da4w(
926923
module: torch.nn.Module,

0 commit comments

Comments
 (0)