
Commit 88ffcbc

Commit message: updates
Parent: aba1a10

2 files changed: +7 −11 lines

scripts/hf_eval.py

Lines changed: 4 additions & 10 deletions
@@ -45,7 +45,7 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
     model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cpu", dtype=precision)
 
     if compile:
-        model = torch.compile(model, mode="max-autotune", fullgraph=True)
+        model = torch.compile(model, fullgraph=True)
 
     if quantization == "int8dq":
         change_linear_weights_to_int8_dqtensors(model)
@@ -57,16 +57,10 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
     elif quantization == "autoquant":
         model = autoquant(model.to(device=device))
     elif quantization == "fp8":
-        from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
-        from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+        from float8_experimental.inference import quantize_to_float8, ActivationCasting, QuantConfig, ScalingGranularity
         model.to(device)
-        swap_linear_with_float8_linear(
-            model,
-            Float8DynamicLinear,
-            from_float_kwargs={
-                "pre_quantize_weight": True,
-            },
-        )
+        quantize_to_float8(model, QuantConfig(ActivationCasting.DYNAMIC), scaling_granularity=ScalingGranularity.TensorWise)
+
     pass # no quantization applied, model is already on device and precision dtype.
 
     with torch.no_grad():
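
Below is a minimal end-to-end sketch of the updated fp8 path in hf_eval.py, assuming float8_experimental is installed and its inference module exposes the names used in the diff above; the checkpoint id, device, and prompt are illustrative and not part of this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from float8_experimental.inference import (
    quantize_to_float8,
    ActivationCasting,
    QuantConfig,
    ScalingGranularity,
)

repo_id = "gpt2"  # illustrative checkpoint, not from this commit
model = AutoModelForCausalLM.from_pretrained(repo_id).to(device="cuda", dtype=torch.bfloat16)

# Optional compile step, matching the updated call (mode="max-autotune" dropped):
model = torch.compile(model, fullgraph=True)

# Swap eligible linear layers for fp8 with dynamic activation casting and a
# single scale per tensor, mirroring the call in hf_eval.py.
quantize_to_float8(
    model,
    QuantConfig(ActivationCasting.DYNAMIC),
    scaling_granularity=ScalingGranularity.TensorWise,
)

tokenizer = AutoTokenizer.from_pretrained(repo_id)
with torch.no_grad():
    inputs = tokenizer("Hello, world", return_tensors="pt").to("cuda")
    out = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(out[0]))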

torchao/_models/llama/generate.py

Lines changed: 3 additions & 1 deletion
@@ -191,7 +191,9 @@ def main(
         autoquant,
         unwrap_tensor_subclass
     )
-
+    if "fp8" in quantization:
+        from float8_experimental.inference import quantize_to_float8, ActivationCasting, QuantConfig
+        quantize_to_float8(model, QuantConfig(ActivationCasting.DYNAMIC))
     if "int8wo" in quantization:
         quantize(model, int8_weight_only())
     if "int8dq" in quantization:
