Skip to content

Commit f3fc096

Browse files
authored
Switch to new ao quant api for 8da4w (#8501)
1 parent 93838e8 commit f3fc096

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

examples/models/llama/source_transformation/quantize.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,10 @@ def quantize(  # noqa C901
         # Check for required args
         if group_size is None:
             raise Exception("For 8da4w quantization, group size must be specified.")
-        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

-        model = Int8DynActInt4WeightQuantizer(
-            precision=torch_dtype, groupsize=group_size
-        ).quantize(model)
+        from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
+
+        quantize_(model, int8_dynamic_activation_int4_weight(group_size=group_size))

         if verbose:
             print("quantized model:", model)
@@ -663,7 +662,7 @@ def convert_for_runtime(self) -> nn.Module:
     def quantized_model(self) -> nn.Module:
         model_updated_state_dict = self.create_quantized_state_dict(self.packed)
         self.convert_for_runtime()
-        self.mod.load_state_dict(model_updated_state_dict)
+        self.mod.load_state_dict(model_updated_state_dict, assign=True)
         return self.mod


0 commit comments

Comments
 (0)