
Commit bfd78c0

fix memory being held by autograd
1 parent e559f2a commit bfd78c0

6 files changed (+122, -113 lines)

test/dtypes/test_affine_quantized.py

Lines changed: 46 additions & 25 deletions
@@ -10,12 +10,33 @@
     int8_dynamic_activation_int8_semi_sparse_weight,
     float8_weight_only,
 )
+from torch.testing._internal import common_utils
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

 import torch
 import unittest
 import tempfile

+is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+
+
+def get_quantization_functions(do_sparse: bool, do_int4: bool):
+    base_functions = [
+        int8_weight_only(),
+        int8_dynamic_activation_int4_weight(),
+        int8_dynamic_activation_int8_weight(),
+    ]
+    if do_int4:
+        base_functions.append(int4_weight_only(group_size=32))
+
+    if do_sparse:
+        base_functions.append(int8_dynamic_activation_int8_semi_sparse_weight())
+
+    if is_cuda_8_9:
+        base_functions.append(float8_weight_only())
+
+    return base_functions
+

 class TestAffineQuantized(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")

@@ -38,36 +59,36 @@ def test_tensor_core_layout_transpose(self):
         self.assertEqual(aqt_shape, shape)

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_weights_only(self):
-        for apply_quant in [int4_weight_only(group_size=32), int8_weight_only(), int8_dynamic_activation_int4_weight(),
-                            int8_dynamic_activation_int8_weight(), int8_dynamic_activation_int8_semi_sparse_weight(), float8_weight_only()]:
-            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-            ql = apply_quant(l)
-            with tempfile.NamedTemporaryFile() as f:
-                torch.save(ql.state_dict(), f)
-                f.seek(0)
-                # `weights_only=True` is enabled for torch 2.5+
-                if TORCH_VERSION_AT_LEAST_2_5:
-                    _ = torch.load(f, weights_only=True)
-                else:
-                    _ = torch.load(f, weights_only=False)
+    @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
+    def test_weights_only(self, apply_quant):
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        ql = apply_quant(l)
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(ql.state_dict(), f)
+            f.seek(0)
+            # `weights_only=True` is enabled for torch 2.5+
+            if TORCH_VERSION_AT_LEAST_2_5:
+                _ = torch.load(f, weights_only=True)
+            else:
+                _ = torch.load(f, weights_only=False)

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_to_device(self):
-        from torchao.quantization import quantize_
-        for apply_quant in [int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight()]:
-            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            ql = apply_quant(l)
-            ql.to("cuda")
+    @common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
+    def test_to_device(self, apply_quant):
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(l)
+        ql.to("cuda")
+
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(l)
+        ql.to(device="cuda")

-            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            ql = apply_quant(l)
-            ql.to(device="cuda")
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(l)
+        ql.cuda()

-            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            ql = apply_quant(l)
-            ql.cuda()


+common_utils.instantiate_parametrized_tests(TestAffineQuantized)

 if __name__ == "__main__":
     run_tests()

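The refactor above replaces hand-rolled for-loops over quantization functions with torch's parametrization helpers, so every entry returned by get_quantization_functions runs as its own reported test case. A minimal, self-contained sketch of that pattern follows; the class and test names are made up for illustration, only the torch.testing helpers are real.

import unittest

import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import TestCase, run_tests


class ExampleParametrized(TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @common_utils.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_linear_forward(self, dtype):
        # Each value of `dtype` becomes a separately reported test case.
        l = torch.nn.Linear(8, 8, dtype=dtype, device="cuda")
        y = l(torch.randn(2, 8, dtype=dtype, device="cuda"))
        self.assertEqual(y.shape, (2, 8))


# Without this call the parametrized variants are never generated.
common_utils.instantiate_parametrized_tests(ExampleParametrized)

if __name__ == "__main__":
    run_tests()

Note that instantiate_parametrized_tests must be called on the test class, exactly as the diff does for TestAffineQuantized; otherwise the parametrized variants never exist.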
test/dtypes/test_affine_quantized_float.py

Lines changed: 6 additions & 44 deletions
@@ -13,10 +13,6 @@
 )
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.testing._internal import common_utils
-from torch.testing._internal.common_utils import (
-    TestCase,
-    run_tests,
-)
 from torch._dynamo.testing import CompileCounterWithBackend

 from torchao.quantization import (

@@ -54,46 +50,9 @@ def forward(self, x):
         return x


-class TestAffineQuantizedFloat8Basic(TestCase):
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_tensor_core_layout_transpose(self):
-        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        t = l.weight
-        shape = t.shape
-        apply_float8_weight_only_quant = float8_weight_only()
-        ql = apply_float8_weight_only_quant(l)
-        aqt = ql.weight
-        aqt_shape = aqt.shape
-        assert aqt_shape == shape
-
-        # transpose shape test
-        for _ in range(10):
-            t = t.t()
-            aqt = aqt.t()
-            shape = t.shape
-            aqt_shape = aqt.shape
-            assert aqt_shape == shape
-
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_weights_only_save_load(self):
-        with torch.no_grad():
-            for apply_quant in [float8_weight_only()]:
-                # TODO Fails when l requires grad
-                l = torch.nn.Linear(128, 256).eval().to(torch.bfloat16).to("cuda")
-                ql = apply_quant(l)
-                with tempfile.NamedTemporaryFile() as f:
-                    torch.save(ql.state_dict(), f)
-                    f.seek(0)
-                    # `weights_only=True` is enabled for torch 2.5+
-                    if TORCH_VERSION_AT_LEAST_2_5:
-                        _ = torch.load(f, weights_only=True)
-                    else:
-                        _ = torch.load(f, weights_only=False)
-
-
 class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_cuda_8_9, "Need H100")
+    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("mode", ["dynamic", "weight-only"])
     @common_utils.parametrize("compile", [True, False])

@@ -108,7 +67,7 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
             ((64, 256), 512, 128),
         ],
     )
-    def test_dynamic_fp8_linear(
+    def test_fp8_linear_variants(
         self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple
     ):
         M, N, K = sizes
@@ -132,7 +91,10 @@ def test_dynamic_fp8_linear(
         output_original = model(input_tensor)
         output_quantized = quantized_model(input_tensor)

-        assert compute_error(output_original, output_quantized) > 20, "Error is too low"
+        error = compute_error(output_original, output_quantized)
+        assert (
+            error > 20
+        ), f"Quantization error is too high, got an SQNR of {error}"


 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)

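The renamed test_fp8_linear_variants gates output quality on compute_error, which the assertion treats as an SQNR in decibels. A rough standalone sketch of such a check follows, assuming compute_error behaves like a signal-to-quantization-noise ratio; the helper below is a stand-in, not the torchao implementation.

import torch


def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB: higher means closer to the reference.
    noise = reference - candidate
    return 20 * torch.log10(reference.norm() / noise.norm())


ref = torch.randn(64, 128)
approx = ref + 0.01 * torch.randn_like(ref)  # stand-in for a quantized model's output
error = sqnr_db(ref, approx)
# The test above requires an SQNR above 20 dB before it passes.
assert error > 20, f"Quantization error is too high, got an SQNR of {error:.1f} dB"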
torchao/dtypes/affine_quantized_tensor.py

Lines changed: 30 additions & 15 deletions
@@ -214,12 +214,12 @@ def from_hp_to_intx(
         block_size: Tuple[int, ...],
         target_dtype: torch.dtype,
         quant_min: Optional[int] = None,
-        quant_max: Optional[int] = None,
+        quant_max: Optional[int] = None,
         eps: Optional[float] = None,
         scale_dtype: Optional[torch.dtype] = None,
         zero_point_dtype: Optional[torch.dtype] = None,
         preserve_zero: bool = True,
-        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
         layout_type: LayoutType = PlainLayoutType(),
         use_hqq: bool = False,
     ):
@@ -237,6 +237,9 @@ def from_hp_to_intx(
             data = data.to(target_dtype)
         else:
             scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
+            # choose_qparams_affine is a custom op that does not support returning optional Tensors, so we set zero_point to None here when its domain is None
+            if zero_point_domain is None:
+                zero_point = None
             data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
         # Note: output will be uint8 tensor for sub byte tensors for now

@@ -262,7 +265,7 @@ def from_hp_to_intx_static(
         block_size: Tuple[int, ...],
         target_dtype: torch.dtype,
         quant_min: Optional[int] = None,
-        quant_max: Optional[int] = None,
+        quant_max: Optional[int] = None,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         layout_type: LayoutType = PlainLayoutType(),
     ):

@@ -291,8 +294,8 @@ def from_hp_to_floatx(
         input_float: torch.Tensor,
         block_size: Tuple[int, ...],
         target_dtype: torch.dtype,
-        scale_dtype: Optional[torch.dtype] = None,
-        layout_type: LayoutType = PlainLayoutType(),
+        scale_dtype: Optional[torch.dtype],
+        layout_type: LayoutType,
     ):

         if target_dtype in FP8_TYPES:

@@ -400,10 +403,8 @@ def extra_repr(self):

 @dataclass(frozen=True)
 class Float8LayoutType(LayoutType):
-    mm_config: ScaledMMConfig
+    mm_config: Optional[ScaledMMConfig]

-    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
-        return input

 @register_layout_cls(PlainLayoutType)
 class PlainAQTLayout(AQTLayout):

@@ -602,9 +603,18 @@ def _apply_fn_to_data(self, fn):
         fn(self.scale)
         return self

+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        return self.__class__(
+            self.float8_data.to(kwargs["device"]),
+            self.scale.to(kwargs["device"]),
+            self.transposed,
+            self.layout_type,
+        )
+
     def __tensor_flatten__(self):
         return ["float8_data", "scale"], [self.transposed, self.layout_type]
-
+
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride

@@ -621,6 +631,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             return return_and_correct_aliasing(
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )
+        if func is aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
         if func is aten.t.default:
             """we don't need to repack the weight and just rely on external
             shape being changed and record the status of transpose/no-transpose

@@ -650,6 +664,7 @@ def from_plain(
     ):
         """ Main entrypoint for constructing Float8Layout Tensor"""
         assert _is_float8_type(data.dtype), f"Float8 Layout must be constructed from float8 dtype but got {data.dtype}"
+        assert isinstance(layout_type, Float8LayoutType), f"Float8 Layout must be constructed from Float8LayoutType but got {layout_type}"
         return cls(data, scale, False, layout_type)

     def __repr__(self):

@@ -1027,14 +1042,14 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):


 def _linear_fp_act_fp8_tensor_wise_weight_check(
-    input_tensor: torch.Tensor,
-    weight_tensor: AffineQuantizedTensor,
+    input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
+    weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
     bias: Optional[torch.Tensor],
 ) -> bool:
-    def check_aqt_tensorwise(aqt: AffineQuantizedTensor) -> bool:
+    def check_aqt_tensorwise(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
         return (
             isinstance(aqt, AffineQuantizedTensor) and
-            isinstance(aqt.layout_tensor, Float8AQTLayout)
+            isinstance(aqt.layout_type, Float8LayoutType)
             and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
             and aqt.shape == aqt.block_size
         )

@@ -1047,7 +1062,7 @@ def _linear_fp_act_fp8_weight_impl(
     bias: Optional[torch.Tensor],
 ):
     """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
-    from torchao.float8.inference import cast_to_float8_e4m3_inference, preprocess_data
+    from torchao.float8.inference import preprocess_data
     from torchao.float8.float8_tensor import ScaledMMConfig
     from torchao.float8.float8_python_api import addmm_float8_unwrapped

@@ -1066,7 +1081,7 @@ def _linear_fp_act_fp8_weight_impl(
     # Handle case where input tensor is more than 2D
     inpt_data = inpt_data.reshape(-1, input_tensor.shape[-1])
     input_scale = input_tensor.layout_tensor.scale
-    if input_scale.dim() >= 2:
+    if input_scale.dim() > 2:
         input_scale = input_scale.reshape(-1, input_scale.shape[-1])

     inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)

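The new to() override and the aten.clone.default branch let Float8AQTLayout be moved and copied like a plain tensor. The commit title points at the underlying pitfall: holding on to a tensor that still carries a grad_fn keeps the whole autograd graph, and every intermediate it references, alive. The following is an illustrative sketch of that effect in plain PyTorch, not torchao code.

import torch

w = torch.randn(1024, 1024, requires_grad=True)
x = torch.randn(1024, 1024)

y = x @ w                    # y.grad_fn references the graph, which references x and w
cached_with_graph = y        # caching y keeps all of that memory reachable

cached_detached = y.detach().clone()  # same values, no grad_fn, graph can be freed
assert cached_detached.grad_fn is None
assert not cached_detached.requires_grad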
torchao/float8/inference.py

Lines changed: 6 additions & 2 deletions
@@ -243,10 +243,14 @@ def quantize_to_float8(
         module_filter_fn=module_filter_fn,
     )

+
 from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul

-def preprocess_data(a_data: torch.Tensor, b_data: torch.Tensor, scaled_mm_config: ScaledMMConfig) -> Tuple[torch.Tensor, torch.Tensor]:
-    """ Preprocess the inner fp8 data tensors for admmm
+
+def preprocess_data(
+    a_data: torch.Tensor, b_data: torch.Tensor, scaled_mm_config: ScaledMMConfig
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Preprocess the inner fp8 data tensors for addmm
     Args:
         a_data: Input tensor A.
         b_data: Input tensor B.

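preprocess_data leans on is_row_major and pad_tensor_for_matmul to get the fp8 operands into a layout that _scaled_mm accepts. A small sketch of what a row-major check is assumed to look like follows; the helper is a local stand-in, not the torchao one.

import torch


def looks_row_major(t: torch.Tensor) -> bool:
    # For a 2-D tensor, row-major means the last dimension is contiguous (stride 1).
    return t.stride(-1) == 1 and t.stride(0) >= t.stride(-1)


a = torch.randn(4, 8)
assert looks_row_major(a)          # freshly created tensors are row-major
assert not looks_row_major(a.t())  # a transposed view is column-major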
torchao/quantization/quant_api.py

Lines changed: 13 additions & 10 deletions
@@ -493,12 +493,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     return int8_dynamic_activation_int8_weight(layout_type=SemiSparseLayoutType())


-def float8_weight_only(target_dtype: torch.dtype = torch.float8_e4m3fn):
+def float8_weight_only(weight_dtype: torch.dtype = torch.float8_e4m3fn):
     """
     Applies float8 weight-only symmetric per-channel quantization to linear layers.

     Args:
-        target_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
+        weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.

     Note:
         The actual matmul will be computed in original precision of the weight tensor.

@@ -511,38 +511,41 @@ def apply_float8wo_quant(weight):
         return to_affine_quantized_floatx(
             input_float=weight,
             block_size=block_size,
-            target_dtype=target_dtype,
+            target_dtype=weight_dtype,
+            scale_dtype=None,
             layout_type=Float8LayoutType(mm_config=None),
         )

     return _get_linear_subclass_inserter(apply_float8wo_quant)


 def float8_dynamic_activation_float8_weight(
-    target_dtype: torch.dtype = torch.float8_e4m3fn,
     activation_dtype: torch.dtype = torch.float8_e4m3fn,
-    mm_config: ScaledMMConfig = ScaledMMConfig(use_fast_accum=True)
+    weight_dtype: torch.dtype = torch.float8_e4m3fn,
+    mm_config: Optional[ScaledMMConfig] = None
 ):
     """
     Applies float8 dynamic symmetric per-tensor quantization to both activations and weights of linear layers.

     Args:
-        target_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m3fn.
+        weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m3fn.
         mm_config (ScaledMMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.

     """
-
     from torchao.dtypes import to_affine_quantized_floatx

+    if mm_config is None:
+        mm_config = ScaledMMConfig(use_fast_accum=True)
+
     #TODO we are hardcoding TensorWise scaling, will follow up PR for Tensorwise scaling
     def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
         quantized_weight = to_affine_quantized_floatx(
             input_float=weight,
             block_size=weight.shape,
-            target_dtype=target_dtype,
+            target_dtype=weight_dtype,
             scale_dtype=torch.float32,
-            layout_type=Float8LayoutType(mm_config=None),
+            layout_type=Float8LayoutType(mm_config=mm_config),
         )

@@ -551,7 +554,7 @@ def input_quant_func(x: torch.Tensor):
             block_size=x.shape,
             target_dtype=activation_dtype,
             scale_dtype=torch.float32,
-            layout_type=Float8LayoutType(mm_config=None),
+            layout_type=Float8LayoutType(mm_config=None),  # Config is stored on weight
         )
         return activation

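For reference, a hedged usage sketch of the renamed float8_weight_only API, mirroring the tests above; it assumes a CUDA device with compute capability >= 8.9 and imports the function from torchao.quantization.quant_api as defined in this file.

import torch
from torchao.quantization.quant_api import float8_weight_only

# weight_dtype replaces the old target_dtype argument.
apply_quant = float8_weight_only(weight_dtype=torch.float8_e4m3fn)

l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(l)  # the weight is swapped for a float8 AffineQuantizedTensor

with torch.no_grad():
    out = ql(torch.randn(16, 128, dtype=torch.bfloat16, device="cuda"))
print(out.shape)  # expected: torch.Size([16, 256])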