
Commit 8bf4032

Add per tensor fp8 conv2d support
Summary: Add fp8 conv2d support by reusing the existing conv3d kernels with the D dimension set to 1:
1. unsqueeze both input and weight at dim 2 (the D dimension)
2. call the fp8 conv3d op from fbgemm, `torch.ops.fbgemm.f8f8bf16_conv`
3. assert that the D dimension of the result has size 1 and remove it with res.squeeze(2)

Test Plan:
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_unsqueeze_conv2d_weight
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py -k test_fp8_conv_variants
1 parent 86af458 commit 8bf4032
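The 2D-via-3D approach described above can be illustrated with a minimal sketch in plain PyTorch (using torch.nn.functional.conv3d as a stand-in for the fbgemm fp8 kernel; shapes and parameter values below are illustrative, not taken from the commit):

import torch
import torch.nn.functional as F

# Sketch of the conv2d-as-conv3d trick: insert a depth dimension of size 1,
# run the 3D convolution, then squeeze the depth dimension back out.
N, C_in, C_out, H, W, k = 2, 16, 32, 8, 8, 3
x = torch.randn(N, C_in, H, W)
w = torch.randn(C_out, C_in, k, k)

ref = F.conv2d(x, w, stride=1, padding=0)      # ordinary 2D convolution

x3d = x.unsqueeze(2)                           # (N, C_in, 1, H, W)
w3d = w.unsqueeze(2)                           # (C_out, C_in, 1, k, k)
out = F.conv3d(x3d, w3d, stride=(1, 1, 1), padding=(0, 0, 0))
assert out.shape[2] == 1                       # depth stays 1
out = out.squeeze(2)                           # back to (N, C_out, H_out, W_out)

torch.testing.assert_close(out, ref)

The quantized path in this commit follows the same shape handling, but additionally prepends 1/0 to stride, padding, and dilation and dispatches to the fp8 conv3d kernel instead of F.conv3d.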

File tree

3 files changed: 220 additions & 36 deletions


test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 126 additions & 24 deletions
@@ -86,6 +86,8 @@ def __init__(
         )
         if dim == 3:
             self.conv = self.conv.to(memory_format=torch.channels_last_3d)
+        elif dim == 2:
+            self.conv = self.conv.to(memory_format=torch.channels_last)

     def forward(self, x):
         return self.conv(x)
@@ -336,33 +338,43 @@ def _test_fp8_matmul_model(
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("compile", [True, False])
     @common_utils.parametrize("inference_mode", [True, False])
-    # only test for 3D conv for now
-    # Inputs are (N, C_in, C_out, D, H, W)
+    # test for 2D/3D conv
+    # Inputs are (N, C_in, C_out, (D, H, W) or
+    # (N, C_in, C_out, (H, W)
     @common_utils.parametrize(
         "sizes",
         [
-            (4, 16, 64, 32, 32, 32),
+            (4, 16, 64, (32, 32, 32)),
+            (4, 16, 64, (32, 32)),
         ],
     )
     def test_fp8_conv_variants(
         self,
         dtype: torch.dtype,
         compile: bool,
         inference_mode: bool,
-        kernel_preference: KernelPreference,
         sizes: Tuple,
     ):
+        torch.compiler.reset()
         granularity = PerTensor()
         kernel_preference = KernelPreference.AUTO
-        N, C_in, C_out, D, H, W = sizes
-        dim = 3
+
+        N, C_in, C_out, spatial_dims = sizes
+        dim = len(spatial_dims)
+        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        assert dim in convs, f"Unsupported dim: {dim}"
+        conv_class = convs[dim]
+
         kernel_size = 3

         # Note: this is channel last memory format
-        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
-        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
+        if dim == 3:
+            input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        else:
+            assert dim == 2
+            input_tensor = input_tensor.to(memory_format=torch.channels_last)

-        # Create a linear layer with bfloat16 dtype
         model = ToyConvModel(
             dim,
             C_in,
@@ -381,9 +393,9 @@ def test_fp8_conv_variants(
             kernel_preference=kernel_preference,
         )

-        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+        _is_conv = lambda m, fqn: isinstance(m, conv_class)

-        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+        quantize_(quantized_model, config, filter_fn=_is_conv)

         if compile:
             quantized_model = torch.compile(quantized_model, fullgraph=True)
@@ -407,13 +419,16 @@ def test_fp8_conv_variants(
         "Requires fbgemm_gpu_genai to be installed",
     )
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
-    # only test for 3D conv for now
-    # Inputs are (N, C_in, C_out, D, H, W)
+    # test for 2D/3D conv
+    # Inputs are (N, C_in, C_out, (D, H, W) or
+    # (N, C_in, C_out, (H, W)
     @common_utils.parametrize(
         "sizes",
         [
-            (4, 12, 64, 32, 32, 32),
-            (4, 16, 12, 32, 32, 32),
+            (4, 12, 64, (32, 32, 32)),
+            (4, 16, 12, (32, 32, 32)),
+            (4, 12, 64, (32, 32)),
+            (4, 16, 12, (32, 32)),
         ],
     )
     def test_fp8_conv_skip_quant(
@@ -426,14 +441,23 @@ def test_fp8_conv_skip_quant(
         """
         granularity = PerTensor()
         kernel_preference = KernelPreference.AUTO
-        N, C_in, C_out, D, H, W = sizes
-        dim = 3
+
+        N, C_in, C_out, spatial_dims = sizes
+
+        dim = len(spatial_dims)
+        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        assert dim in convs, f"Unsupported dim: {dim}"
+        conv_class = convs[dim]
+
         kernel_size = 3

         # Note: this is channel last memory format
-        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
-        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
-        # Create a linear layer with bfloat16 dtype
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
+        if dim == 3:
+            input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        else:
+            input_tensor = input_tensor.to(memory_format=torch.channels_last)
+
         model = ToyConvModel(
             dim,
             C_in,
@@ -452,9 +476,9 @@ def test_fp8_conv_skip_quant(
             kernel_preference=kernel_preference,
         )

-        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+        _is_conv = lambda m, fqn: isinstance(m, conv_class)

-        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+        quantize_(quantized_model, config, filter_fn=_is_conv)
         assert not isinstance(quantized_model.conv.weight, Float8Tensor)

         output_original = model(input_tensor)
@@ -793,7 +817,6 @@ def test_index_select(self):
         ],
     )
     def test_unsqueeze_operation(self, granularity, sizes):
-        """Test aten.unsqueeze.default operation on Float8Tensor"""
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
         device = "cuda"
@@ -806,7 +829,7 @@ def test_unsqueeze_operation(self, granularity, sizes):
         original_weight = linear.weight
         original_shape = original_weight.shape

-        # Test unsqueeze operation at dim=0 (only supported dimension)
+        # Test unsqueeze operation at dim=0
         unsqueezed_weight = original_weight.unsqueeze(0)

         # Verify the unsqueezed tensor has correct shape
@@ -848,6 +871,85 @@ def test_unsqueeze_operation(self, granularity, sizes):

         self.assertEqual(unsqueezed_dequant, expected_dequant)

+    def test_unsqueeze_conv2d_weight(self):
+        granularity = PerTensor()
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+        dtype = torch.bfloat16
+        device = "cuda"
+        N, C_in, C_out, spatial_dims = 4, 16, 64, (32, 32)
+        dim = len(spatial_dims)
+        kernel_size = 3
+
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device=device)
+        input_tensor = input_tensor.to(memory_format=torch.channels_last)
+        model = ToyConvModel(
+            dim,
+            C_in,
+            C_out,
+            kernel_size,
+            bias=False,
+            padding=0,
+            dtype=dtype,
+            device=device,
+        ).eval()
+
+        quantized_model = copy.deepcopy(model)
+
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity,
+        )
+
+        _is_conv = lambda m, fqn: isinstance(m, torch.nn.Conv2d)
+
+        quantize_(quantized_model, config, filter_fn=_is_conv)
+
+        original_weight = quantized_model.conv.weight
+        original_shape = original_weight.shape
+
+        # Test unsqueeze operation at dim=2
+        unsqueezed_weight = original_weight.unsqueeze(2)
+
+        # Verify the unsqueezed tensor has correct shape
+        original_shape_list = list(original_shape)
+        expected_shape = original_shape_list[:2] + [1] + original_shape_list[2:]
+        scale_shape_list = list(original_weight.scale.shape)
+        expected_scale_shape = scale_shape_list[:2] + [1] + scale_shape_list[2:]
+
+        self.assertEqual(unsqueezed_weight.shape, torch.Size(expected_shape))
+        # Verify qdata and scale shapes
+        expected_qdata_shape = expected_shape
+
+        self.assertEqual(
+            unsqueezed_weight.qdata.shape, torch.Size(expected_qdata_shape)
+        )
+        self.assertEqual(
+            unsqueezed_weight.scale.shape, torch.Size(expected_scale_shape)
+        )
+
+        # Verify block_size is correctly updated
+        expected_block_size = []
+        for i in range(len(expected_shape)):
+            expected_block_size.append(expected_shape[i] // expected_scale_shape[i])
+
+        self.assertEqual(unsqueezed_weight.block_size, expected_block_size)
+
+        # Test that metadata is preserved
+        self.assertEqual(unsqueezed_weight.mm_config, original_weight.mm_config)
+        self.assertEqual(
+            unsqueezed_weight.act_quant_kwargs, original_weight.act_quant_kwargs
+        )
+        self.assertEqual(
+            unsqueezed_weight.kernel_preference, original_weight.kernel_preference
+        )
+        self.assertEqual(unsqueezed_weight.dtype, original_weight.dtype)
+
+        # Test numerical correctness
+        original_dequant = original_weight.dequantize()
+        unsqueezed_dequant = unsqueezed_weight.dequantize()
+        expected_dequant = original_dequant.unsqueeze(2)
+
+        self.assertEqual(unsqueezed_dequant, expected_dequant)
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_unsqueeze_error_cases(self, granularity):
         """Test error cases for aten.unsqueeze.default operation"""

torchao/quantization/quant_api.py

Lines changed: 5 additions & 4 deletions
@@ -1816,13 +1816,14 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     _check_hardware_support(granularity)
     activation_granularity, weight_granularity = granularity

-    if weight.dim() == 5:
-        # weights for conv3d
+    if weight.dim() in [4, 5]:
+        # weights for conv2d or 3d
         assert isinstance(activation_granularity, PerTensor) and isinstance(
             weight_granularity, PerTensor
-        ), "5D tensor only supports per tensor activation and weight quantization"
+        ), "4D/5D tensor only supports per tensor activation and weight quantization"

-        # weight dim: (C_out, C_in, K1, K2, K3)
+        # conv3d weight dim: (C_out, C_in, K1, K2, K3)
+        # conv2d weight dim: (C_out, C_in, K1, K2)
         # skip quantization when either C_out or C_in
         # is not a multiple of 16
         if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
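For reference, the channel-alignment gate added above amounts to the following standalone check (a sketch only; the function name is hypothetical, not part of torchao's API):

import torch

def should_quantize_conv_weight(weight: torch.Tensor) -> bool:
    # conv2d weight: (C_out, C_in, K1, K2); conv3d weight: (C_out, C_in, K1, K2, K3)
    if weight.dim() not in (4, 5):
        return True  # not a conv weight covered by this gate
    # mirror the diff: skip quantization when C_out or C_in is not a multiple of 16
    return weight.shape[0] % 16 == 0 and weight.shape[1] % 16 == 0

print(should_quantize_conv_weight(torch.empty(64, 16, 3, 3)))  # True  -> quantize
print(should_quantize_conv_weight(torch.empty(64, 12, 3, 3)))  # False -> keep high precision

This matches the test_fp8_conv_skip_quant cases above, where C_in=12 or C_out=12 leaves the weight as a regular tensor rather than a Float8Tensor.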

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 89 additions & 8 deletions
@@ -537,6 +537,7 @@ def _quantize_and_scaled_conv3d(

     # move C_in to last dim
     # after permute: (C_out, K1, K2, K3, C_in)
+
     weight_qdata = weight_tensor.qdata.permute([0, 2, 3, 4, 1])

     assert act_qdata.is_contiguous() and weight_qdata.is_contiguous(), (
@@ -572,10 +573,71 @@ def _(func, types, args, kwargs):
         groups,
     ) = args
     assert not transposed, "transposed conv is not supported currently"
-    assert tuple(output_padding) == (0, 0, 0), (
-        f"Only (0, 0, 0) is supported for `output_padding`, got: f{output_padding}"
-    )
+    dim = len(output_padding)
+    assert dim in [2, 3], "Only 2d or 3d convs are supported"
     assert groups == 1, f"Only 1 is supported for `groups`, got: {groups}"
+
+    if dim == 2:
+        assert input_tensor.is_contiguous(
+            memory_format=torch.channels_last
+        ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last), (
+            "Please make sure both activation and weights are in the `channels_last` memory_format"
+        )
+        # (N, C, H, W) --> (N, C, 1, H, W)
+        input_tensor = input_tensor.unsqueeze(2)
+        weight_tensor = weight_tensor.unsqueeze(2)
+        assert tuple(output_padding) == (0, 0), (
+            f"Only (0, 0) is supported for `output_padding`, got: f{output_padding}"
+        )
+        padding = [0, *padding]
+        stride = [1, *stride]
+        dilation = [1, *dilation]
+        res = _quantize_and_scaled_conv3d(
+            input_tensor,
+            weight_tensor,
+            bias,
+            stride,
+            padding,
+            dilation,
+        )
+        assert res.shape[2] == 1
+        res = res.squeeze(2)
+        return res
+    else:
+        assert input_tensor.is_contiguous(
+            memory_format=torch.channels_last_3d
+        ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last_3d), (
+            "Please make sure both activation and weights are in the `channels_last_3d` memory_format"
+        )
+        assert tuple(output_padding) == (0, 0, 0), (
+            f"Only (0, 0, 0) is supported for `output_padding`, got: f{output_padding}"
+        )
+        return _quantize_and_scaled_conv3d(
+            input_tensor,
+            weight_tensor,
+            bias,
+            stride,
+            padding,
+            dilation,
+        )
+
+
+@implements(aten.conv3d.default)
+def _(func, types, args, kwargs):
+    (
+        input_tensor,
+        weight_tensor,
+        bias,
+        stride,
+        padding,
+        dilation,
+        groups,
+    ) = fill_defaults(args, 7, [None, [1, 1, 1], [0, 0, 0], [1, 1, 1], 1])
+    assert input_tensor.is_contiguous(
+        memory_format=torch.channels_last_3d
+    ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last_3d), (
+        "Please make sure both activation and weights are in the `channels_last_3d` memory_format"
+    )
     return _quantize_and_scaled_conv3d(
         input_tensor,
         weight_tensor,
@@ -586,7 +648,7 @@ def _(func, types, args, kwargs):
     )


-@implements(aten.conv3d.default)
+@implements(aten.conv2d.default)
 def _(func, types, args, kwargs):
     (
         input_tensor,
@@ -596,16 +658,36 @@ def _(func, types, args, kwargs):
         padding,
         dilation,
         groups,
-    ) = fill_defaults(args, 7, [None, [1, 1, 1], [0, 0, 0], [1, 1, 1], 1])
-    assert groups == 1, f"Only 1 is supported for `groups`, got: {groups}"
-    return _quantize_and_scaled_conv3d(
+    ) = fill_defaults(args, 7, [None, [1, 1], [0, 0], [1, 1], 1])
+    # (N, C, H, W) --> (N, C, 1, H, W)
+    # memory_format of both tensors should be torch.channels_last
+    # and it should be preserved with unsqueeze(2) (becoming torch.channels_last_3d)
+    assert input_tensor.is_contiguous(
+        memory_format=torch.channels_last
+    ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last), (
+        "Please make sure both activation and weights are in the `channels_last` memory_format"
+    )
+    input_tensor = input_tensor.unsqueeze(2)
+    weight_tensor = weight_tensor.unsqueeze(2)
+
+    assert input_tensor.is_contiguous(
+        memory_format=torch.channels_last_3d
+    ) and weight_tensor.qdata.is_contiguous(memory_format=torch.channels_last_3d)
+
+    padding = [0, *padding]
+    stride = [1, *stride]
+    dilation = [1, *dilation]
+    res = _quantize_and_scaled_conv3d(
         input_tensor,
         weight_tensor,
         bias,
         stride,
         padding,
         dilation,
     )
+    assert res.shape[2] == 1
+    res = res.squeeze(2)
+    return res


 @implements(aten.slice.Tensor)
@@ -837,7 +919,6 @@ def _(func, types, args, kwargs):
 @implements(aten.unsqueeze.default)
 def _(func, types, args, kwargs):
     self, dim = args
-    assert dim == 0, f"Only dim == 0 is supported, got: {dim}"
     qdata = self.qdata.unsqueeze(dim=dim)
     scale = self.scale.unsqueeze(dim=dim)
     block_size = []
