
Commit ffeeb9a

fix memory being held by autograd
1 parent e559f2a commit ffeeb9a

6 files changed: +108 -108 lines changed

test/dtypes/test_affine_quantized.py

Lines changed: 45 additions & 25 deletions
@@ -10,12 +10,32 @@
     int8_dynamic_activation_int8_semi_sparse_weight,
     float8_weight_only,
 )
+from torch.testing._internal import common_utils
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

 import torch
 import unittest
 import tempfile

+is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+
+
+def get_quantization_functions(do_sparse: bool, do_int4: bool):
+    base_functions = [
+        int8_weight_only(),
+        int8_dynamic_activation_int4_weight(),
+        int8_dynamic_activation_int8_weight(),
+    ]
+    if do_int4:
+        base_functions.append(int4_weight_only(group_size=32))
+    if do_sparse:
+        base_functions.append(int8_dynamic_activation_int8_semi_sparse_weight())
+
+    if is_cuda_8_9:
+        base_functions.append(float8_weight_only())
+
+    return base_functions
+

 class TestAffineQuantized(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@@ -38,36 +58,36 @@ def test_tensor_core_layout_transpose(self):
         self.assertEqual(aqt_shape, shape)

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_weights_only(self):
-        for apply_quant in [int4_weight_only(group_size=32), int8_weight_only(), int8_dynamic_activation_int4_weight(),
-                            int8_dynamic_activation_int8_weight(), int8_dynamic_activation_int8_semi_sparse_weight(), float8_weight_only()]:
-            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-            ql = apply_quant(l)
-            with tempfile.NamedTemporaryFile() as f:
-                torch.save(ql.state_dict(), f)
-                f.seek(0)
-                # `weights_only=True` is enabled for torch 2.5+
-                if TORCH_VERSION_AT_LEAST_2_5:
-                    _ = torch.load(f, weights_only=True)
-                else:
-                    _ = torch.load(f, weights_only=False)
+    @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
+    def test_weights_only(self, apply_quant):
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        ql = apply_quant(l)
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(ql.state_dict(), f)
+            f.seek(0)
+            # `weights_only=True` is enabled for torch 2.5+
+            if TORCH_VERSION_AT_LEAST_2_5:
+                _ = torch.load(f, weights_only=True)
+            else:
+                _ = torch.load(f, weights_only=False)

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_to_device(self):
-        from torchao.quantization import quantize_
-        for apply_quant in [int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight()]:
-            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            ql = apply_quant(l)
-            ql.to("cuda")
+    @common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
+    def test_to_device(self, apply_quant):
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(l)
+        ql.to("cuda")
+
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(l)
+        ql.to(device="cuda")

-            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            ql = apply_quant(l)
-            ql.to(device="cuda")
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(l)
+        ql.cuda()

-            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            ql = apply_quant(l)
-            ql.cuda()

+common_utils.instantiate_parametrized_tests(TestAffineQuantized)

 if __name__ == "__main__":
     run_tests()
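For readers unfamiliar with this test pattern, here is a minimal, self-contained sketch (not part of the commit; the helper names are illustrative) of how common_utils.parametrize plus instantiate_parametrized_tests expands one test body into one generated test per value, which is how the quantization functions above each get their own test case:

# Illustrative sketch only -- a trivial stand-in for get_quantization_functions().
import torch
from torch.testing._internal import common_utils


def get_example_functions():
    # Each entry becomes its own generated test case.
    return [lambda t: t * 2, lambda t: t * 4]


class TestParametrizePattern(common_utils.TestCase):
    @common_utils.parametrize("apply_fn", get_example_functions())
    def test_apply(self, apply_fn):
        x = torch.ones(4)
        self.assertEqual(apply_fn(x).shape, x.shape)


# Generates one concrete test method per parametrized value on the class.
common_utils.instantiate_parametrized_tests(TestParametrizePattern)

if __name__ == "__main__":
    common_utils.run_tests()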

test/dtypes/test_affine_quantized_float.py

Lines changed: 6 additions & 44 deletions
@@ -13,10 +13,6 @@
 )
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.testing._internal import common_utils
-from torch.testing._internal.common_utils import (
-    TestCase,
-    run_tests,
-)
 from torch._dynamo.testing import CompileCounterWithBackend

 from torchao.quantization import (
@@ -54,46 +50,9 @@ def forward(self, x):
         return x


-class TestAffineQuantizedFloat8Basic(TestCase):
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_tensor_core_layout_transpose(self):
-        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        t = l.weight
-        shape = t.shape
-        apply_float8_weight_only_quant = float8_weight_only()
-        ql = apply_float8_weight_only_quant(l)
-        aqt = ql.weight
-        aqt_shape = aqt.shape
-        assert aqt_shape == shape
-
-        # transpose shape test
-        for _ in range(10):
-            t = t.t()
-            aqt = aqt.t()
-            shape = t.shape
-            aqt_shape = aqt.shape
-            assert aqt_shape == shape
-
-    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_weights_only_save_load(self):
-        with torch.no_grad():
-            for apply_quant in [float8_weight_only()]:
-                # TODO Fails when l requires grad
-                l = torch.nn.Linear(128, 256).eval().to(torch.bfloat16).to("cuda")
-                ql = apply_quant(l)
-                with tempfile.NamedTemporaryFile() as f:
-                    torch.save(ql.state_dict(), f)
-                    f.seek(0)
-                    # `weights_only=True` is enabled for torch 2.5+
-                    if TORCH_VERSION_AT_LEAST_2_5:
-                        _ = torch.load(f, weights_only=True)
-                    else:
-                        _ = torch.load(f, weights_only=False)
-
-
 class TestAffineQuantizedFloat8Compile(InductorTestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_cuda_8_9, "Need H100")
+    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("mode", ["dynamic", "weight-only"])
     @common_utils.parametrize("compile", [True, False])
@@ -108,7 +67,7 @@ class TestAffineQuantizedFloat8Compile(InductorTestCase):
             ((64, 256), 512, 128),
         ],
     )
-    def test_dynamic_fp8_linear(
+    def test_fp8_linear_variants(
        self, dtype: torch.dtype, mode: str, compile: bool, sizes: tuple
     ):
         M, N, K = sizes
@@ -132,7 +91,10 @@ def test_dynamic_fp8_linear(
         output_original = model(input_tensor)
         output_quantized = quantized_model(input_tensor)

-        assert compute_error(output_original, output_quantized) > 20, "Error is too low"
+        error = compute_error(output_original, output_quantized)
+        assert (
+            error > 20
+        ), f"Quantization error is too high, got an SQNR of {error}"


 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
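For context on the threshold above: compute_error is used here as an SQNR-style metric in dB, where higher means the quantized output tracks the reference more closely. A minimal sketch of such a metric, written independently of torchao's actual implementation:

import torch


def sqnr_db(reference: torch.Tensor, quantized: torch.Tensor) -> torch.Tensor:
    # 20 * log10(||signal|| / ||noise||): signal-to-quantization-noise ratio in dB.
    signal = torch.linalg.norm(reference)
    noise = torch.linalg.norm(reference - quantized)
    return 20 * torch.log10(signal / noise)


ref = torch.randn(64, 128)
noisy = ref + 1e-2 * torch.randn_like(ref)  # simulated small quantization error
assert sqnr_db(ref, noisy) > 20  # same style of check as the test above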

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 29 additions & 15 deletions
@@ -214,12 +214,12 @@ def from_hp_to_intx(
         block_size: Tuple[int, ...],
         target_dtype: torch.dtype,
         quant_min: Optional[int] = None,
-        quant_max: Optional[int] = None,
+        quant_max: Optional[int] = None,
         eps: Optional[float] = None,
         scale_dtype: Optional[torch.dtype] = None,
         zero_point_dtype: Optional[torch.dtype] = None,
         preserve_zero: bool = True,
-        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
         layout_type: LayoutType = PlainLayoutType(),
         use_hqq: bool = False,
     ):
@@ -237,6 +237,8 @@ def from_hp_to_intx(
             data = data.to(target_dtype)
         else:
             scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
+            if zero_point_domain is None:
+                zero_point = None
             data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain)
         # Note: output will be uint8 tensor for sub byte tensors for now

@@ -262,7 +264,7 @@ def from_hp_to_intx_static(
         block_size: Tuple[int, ...],
         target_dtype: torch.dtype,
         quant_min: Optional[int] = None,
-        quant_max: Optional[int] = None,
+        quant_max: Optional[int] = None,
         zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
         layout_type: LayoutType = PlainLayoutType(),
     ):
@@ -291,8 +293,8 @@ def from_hp_to_floatx(
         input_float: torch.Tensor,
         block_size: Tuple[int, ...],
         target_dtype: torch.dtype,
-        scale_dtype: Optional[torch.dtype] = None,
-        layout_type: LayoutType = PlainLayoutType(),
+        scale_dtype: Optional[torch.dtype],
+        layout_type: LayoutType,
     ):

         if target_dtype in FP8_TYPES:
@@ -400,10 +402,8 @@ def extra_repr(self):

 @dataclass(frozen=True)
 class Float8LayoutType(LayoutType):
-    mm_config: ScaledMMConfig
+    mm_config: Optional[ScaledMMConfig]

-    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
-        return input

 @register_layout_cls(PlainLayoutType)
 class PlainAQTLayout(AQTLayout):
@@ -602,9 +602,18 @@ def _apply_fn_to_data(self, fn):
         fn(self.scale)
         return self

+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        return self.__class__(
+            self.float8_data.to(kwargs["device"]),
+            self.scale.to(kwargs["device"]),
+            self.transposed,
+            self.layout_type,
+        )
+
     def __tensor_flatten__(self):
         return ["float8_data", "scale"], [self.transposed, self.layout_type]
-
+
     @classmethod
     def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
@@ -621,6 +630,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             return return_and_correct_aliasing(
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
             )
+        if func is aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
         if func is aten.t.default:
             """we don't need to repack the weight and just rely on external
             shape being changed and record the status of transpose/no-transpose
@@ -650,6 +663,7 @@ def from_plain(
     ):
         """ Main entrypoint for constructing Float8Layout Tensor"""
         assert _is_float8_type(data.dtype), f"Float8 Layout must be constructed from float8 dtype but got {data.dtype}"
+        assert isinstance(layout_type, Float8LayoutType), f"Float8 Layout must be constructed from Float8LayoutType but got {layout_type}"
         return cls(data, scale, False, layout_type)

     def __repr__(self):
@@ -1027,14 +1041,14 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias):


 def _linear_fp_act_fp8_tensor_wise_weight_check(
-    input_tensor: torch.Tensor,
-    weight_tensor: AffineQuantizedTensor,
+    input_tensor: Union[torch.Tensor, AffineQuantizedTensor],
+    weight_tensor: Union[torch.Tensor, AffineQuantizedTensor],
     bias: Optional[torch.Tensor],
 ) -> bool:
-    def check_aqt_tensorwise(aqt: AffineQuantizedTensor) -> bool:
+    def check_aqt_tensorwise(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
         return (
             isinstance(aqt, AffineQuantizedTensor) and
-            isinstance(aqt.layout_tensor, Float8AQTLayout)
+            isinstance(aqt.layout_type, Float8LayoutType)
             and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
             and aqt.shape == aqt.block_size
         )
@@ -1047,7 +1061,7 @@ def _linear_fp_act_fp8_weight_impl(
     bias: Optional[torch.Tensor],
 ):
     """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
-    from torchao.float8.inference import cast_to_float8_e4m3_inference, preprocess_data
+    from torchao.float8.inference import preprocess_data
     from torchao.float8.float8_tensor import ScaledMMConfig
     from torchao.float8.float8_python_api import addmm_float8_unwrapped

@@ -1066,7 +1080,7 @@ def _linear_fp_act_fp8_weight_impl(
     # Handle case where input tensor is more than 2D
     inpt_data = inpt_data.reshape(-1, input_tensor.shape[-1])
     input_scale = input_tensor.layout_tensor.scale
-    if input_scale.dim() >= 2:
+    if input_scale.dim() > 2:
         input_scale = input_scale.reshape(-1, input_scale.shape[-1])

     inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)
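The new Float8AQTLayout.to above follows a common tensor-subclass pattern: rather than mutating in place or relying on generic fallback handling, it rebuilds the wrapper with each inner tensor explicitly moved to the target device. A minimal sketch of that pattern with hypothetical names (not torchao's classes):

from dataclasses import dataclass

import torch


@dataclass
class ExampleLayout:
    # Hypothetical stand-in for a layout object holding several inner tensors.
    data: torch.Tensor
    scale: torch.Tensor
    transposed: bool

    def to(self, device: torch.device) -> "ExampleLayout":
        # Return a new instance with every inner tensor moved explicitly.
        return ExampleLayout(
            self.data.to(device),
            self.scale.to(device),
            self.transposed,
        )


layout = ExampleLayout(torch.zeros(4, 4), torch.ones(1), False)
moved = layout.to(torch.device("cpu"))  # use torch.device("cuda") on a GPU machine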

torchao/float8/inference.py

Lines changed: 6 additions & 2 deletions
@@ -243,10 +243,14 @@ def quantize_to_float8(
         module_filter_fn=module_filter_fn,
     )

+
 from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul

-def preprocess_data(a_data: torch.Tensor, b_data: torch.Tensor, scaled_mm_config: ScaledMMConfig) -> Tuple[torch.Tensor, torch.Tensor]:
-    """ Preprocess the inner fp8 data tensors for admmm
+
+def preprocess_data(
+    a_data: torch.Tensor, b_data: torch.Tensor, scaled_mm_config: ScaledMMConfig
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Preprocess the inner fp8 data tensors for addmm
     Args:
         a_data: Input tensor A.
         b_data: Input tensor B.
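As a rough, independent illustration of why the fp8 operands need preprocessing at all (this is not torchao's implementation): scaled-matmul kernels typically want the right-hand operand laid out column-major, which can be expressed in plain PyTorch as:

import torch


def to_column_major(t: torch.Tensor) -> torch.Tensor:
    # A column-major 2D tensor is the transpose of a contiguous (row-major) tensor.
    return t.t().contiguous().t()


b = torch.randn(128, 64)
b_cm = to_column_major(b)
assert b_cm.stride() == (1, b_cm.shape[0])  # column-major stride pattern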

torchao/quantization/quant_api.py

Lines changed: 9 additions & 6 deletions
@@ -512,16 +512,17 @@ def apply_float8wo_quant(weight):
             input_float=weight,
             block_size=block_size,
             target_dtype=target_dtype,
+            scale_dtype=None,
             layout_type=Float8LayoutType(mm_config=None),
         )

     return _get_linear_subclass_inserter(apply_float8wo_quant)


 def float8_dynamic_activation_float8_weight(
-    target_dtype: torch.dtype = torch.float8_e4m3fn,
     activation_dtype: torch.dtype = torch.float8_e4m3fn,
-    mm_config: ScaledMMConfig = ScaledMMConfig(use_fast_accum=True)
+    weight_dtype: torch.dtype = torch.float8_e4m3fn,
+    mm_config: Optional[ScaledMMConfig] = None
 ):
     """
     Applies float8 dynamic symmetric per-tensor quantization to both activations and weights of linear layers.
@@ -532,17 +533,19 @@ def float8_dynamic_activation_float8_weight(
         mm_config (ScaledMMConfig): Configuration for the matrix multiplication. Default uses fast accumulation.

     """
-
     from torchao.dtypes import to_affine_quantized_floatx

+    if mm_config is None:
+        mm_config = ScaledMMConfig(use_fast_accum=True)
+
     #TODO we are hardcoding TensorWise scaling, will follow up PR for Tensorwise scaling
     def apply_float8_dynamic_activation_quant(weight: torch.Tensor):
         quantized_weight = to_affine_quantized_floatx(
             input_float=weight,
             block_size=weight.shape,
-            target_dtype=target_dtype,
+            target_dtype=weight_dtype,
             scale_dtype=torch.float32,
-            layout_type=Float8LayoutType(mm_config=None),
+            layout_type=Float8LayoutType(mm_config=mm_config),
         )

     def input_quant_func(x: torch.Tensor):
@@ -551,7 +554,7 @@ def input_quant_func(x: torch.Tensor):
             block_size=x.shape,
             target_dtype=activation_dtype,
             scale_dtype=torch.float32,
-            layout_type=Float8LayoutType(mm_config=None),
+            layout_type=Float8LayoutType(mm_config=mm_config),
         )
         return activation

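For orientation, a hedged usage sketch of the reworked factory (assuming the quantize_ entry point from torchao.quantization, which the removed test code above also used, and importing the factory from torchao.quantization.quant_api where this diff defines it):

# Illustrative usage only; requires a GPU with fp8 support (compute capability >= 8.9).
import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import float8_dynamic_activation_float8_weight

model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16).to("cuda")
# With mm_config=None the factory falls back to ScaledMMConfig(use_fast_accum=True)
# and threads the same config into both the weight and activation Float8LayoutType.
quantize_(model, float8_dynamic_activation_float8_weight())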
