@@ -893,10 +893,15 @@ def _replace_linear_8da4w(
893
893
linear_class : Type [torch .nn .Module ],
894
894
copy_weights : bool = False ,
895
895
):
896
- for name , child in module .named_children ():
897
- if isinstance (child , nn .Linear ):
898
- if _check_linear_int4_k (child .in_features , groupsize ) or padding_allowed :
899
- new_linear = linear_class (
896
+
897
+ #import the util function here to avoid circular dependency
898
+ from torchao .quantization .quant_api import _replace_with_custom_fn_if_matches_filter
899
+
900
+ def filter_fn (child : torch .nn .Module , cur_fqn :str ) -> bool :
901
+ return isinstance (child , nn .Linear ) and (_check_linear_int4_k (child .in_features , groupsize ) or padding_allowed )
902
+
903
+ def replacement_fn (child : torch .nn .Module ) -> torch .nn .Module :
904
+ new_linear = linear_class (
900
905
child .in_features ,
901
906
child .out_features ,
902
907
bias = False ,
@@ -905,22 +910,14 @@ def _replace_linear_8da4w(
905
910
precision = precision ,
906
911
scales_precision = scales_precision ,
907
912
)
908
- # In distributed training, the model may be instantiated
909
- # on the meta device, in which case there is no need to
910
- # copy the weights, and doing so will result in an error
911
- if copy_weights and child .weight .device != torch .device ("meta" ):
912
- new_linear .weight = child .weight
913
- setattr (module , name , new_linear )
914
- else :
915
- _replace_linear_8da4w (
916
- child ,
917
- groupsize ,
918
- padding_allowed ,
919
- precision ,
920
- scales_precision ,
921
- linear_class ,
922
- copy_weights ,
923
- )
913
+ # In distributed training, the model may be instantiated
914
+ # on the meta device, in which case there is no need to
915
+ # copy the weights, and doing so will result in an error
916
+ if copy_weights and child .weight .device != torch .device ("meta" ):
917
+ new_linear .weight = child .weight
918
+ return new_linear
919
+
920
+ _replace_with_custom_fn_if_matches_filter (module , replacement_fn , filter_fn )
924
921
925
922
def replace_linear_8da4w (
926
923
module : torch .nn .Module ,
0 commit comments