@@ -87,6 +87,8 @@ def __init__(
         )
         if dim == 3:
             self.conv = self.conv.to(memory_format=torch.channels_last_3d)
+        elif dim == 2:
+            self.conv = self.conv.to(memory_format=torch.channels_last)

     def forward(self, x):
         return self.conv(x)
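
For context, here is a minimal standalone sketch (not part of the diff) of the memory-format handling added above: the module moves a 2D convolution to torch.channels_last and a 3D convolution to torch.channels_last_3d. The layer sizes are illustrative only.

import torch

# Mirror of the dim == 2 / dim == 3 branches added to ToyConvModel.__init__:
# each convolution is moved to the channels-last variant matching its dimensionality.
conv2d = torch.nn.Conv2d(16, 64, kernel_size=3).to(memory_format=torch.channels_last)
conv3d = torch.nn.Conv3d(16, 64, kernel_size=3).to(memory_format=torch.channels_last_3d)

x2d = torch.randn(4, 16, 32, 32).to(memory_format=torch.channels_last)
x3d = torch.randn(4, 16, 32, 32, 32).to(memory_format=torch.channels_last_3d)

print(conv2d(x2d).shape)  # torch.Size([4, 64, 30, 30])
print(conv3d(x3d).shape)  # torch.Size([4, 64, 30, 30, 30])
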
@@ -337,33 +339,43 @@ def _test_fp8_matmul_model(
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("compile", [True, False])
     @common_utils.parametrize("inference_mode", [True, False])
-    # only test for 3D conv for now
-    # Inputs are (N, C_in, C_out, D, H, W)
+    # test for 2D/3D conv
+    # Inputs are (N, C_in, C_out, (D, H, W)) or
+    # (N, C_in, C_out, (H, W))
     @common_utils.parametrize(
         "sizes",
         [
-            (4, 16, 64, 32, 32, 32),
+            (4, 16, 64, (32, 32, 32)),
+            (4, 16, 64, (32, 32)),
         ],
     )
     def test_fp8_conv_variants(
         self,
         dtype: torch.dtype,
         compile: bool,
         inference_mode: bool,
-        kernel_preference: KernelPreference,
         sizes: Tuple,
     ):
+        torch.compiler.reset()
         granularity = PerTensor()
         kernel_preference = KernelPreference.AUTO
-        N, C_in, C_out, D, H, W = sizes
-        dim = 3
+
+        N, C_in, C_out, spatial_dims = sizes
+        dim = len(spatial_dims)
+        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        assert dim in convs, f"Unsupported dim: {dim}"
+        conv_class = convs[dim]
+
         kernel_size = 3

         # Note: this is channel last memory format
-        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
-        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
+        if dim == 3:
+            input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        else:
+            assert dim == 2
+            input_tensor = input_tensor.to(memory_format=torch.channels_last)

-        # Create a linear layer with bfloat16 dtype
         model = ToyConvModel(
             dim,
             C_in,
@@ -382,9 +394,9 @@ def test_fp8_conv_variants(
             kernel_preference=kernel_preference,
         )

-        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+        _is_conv = lambda m, fqn: isinstance(m, conv_class)

-        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+        quantize_(quantized_model, config, filter_fn=_is_conv)

         if compile:
             quantized_model = torch.compile(quantized_model, fullgraph=True)
@@ -408,13 +420,16 @@ def test_fp8_conv_variants(
408420 "Requires fbgemm_gpu_genai to be installed" ,
409421 )
410422 @common_utils .parametrize ("dtype" , [torch .bfloat16 , torch .float32 ])
411- # only test for 3D conv for now
412- # Inputs are (N, C_in, C_out, D, H, W)
423+ # test for 2D/3D conv
+    # Inputs are (N, C_in, C_out, (D, H, W)) or
+    # (N, C_in, C_out, (H, W))
     @common_utils.parametrize(
         "sizes",
         [
-            (4, 12, 64, 32, 32, 32),
-            (4, 16, 12, 32, 32, 32),
+            (4, 12, 64, (32, 32, 32)),
+            (4, 16, 12, (32, 32, 32)),
+            (4, 12, 64, (32, 32)),
+            (4, 16, 12, (32, 32)),
         ],
     )
     def test_fp8_conv_skip_quant(
@@ -427,14 +442,23 @@ def test_fp8_conv_skip_quant(
427442 """
428443 granularity = PerTensor ()
429444 kernel_preference = KernelPreference .AUTO
430- N , C_in , C_out , D , H , W = sizes
431- dim = 3
445+
446+ N , C_in , C_out , spatial_dims = sizes
447+
448+ dim = len (spatial_dims )
449+ convs = {1 : torch .nn .Conv1d , 2 : torch .nn .Conv2d , 3 : torch .nn .Conv3d }
450+ assert dim in convs , f"Unsupported dim: { dim } "
451+ conv_class = convs [dim ]
452+
432453 kernel_size = 3
433454
434455 # Note: this is channel last memory format
435- input_tensor = torch .randn (N , C_in , D , H , W , dtype = dtype , device = "cuda" )
436- input_tensor = input_tensor .to (memory_format = torch .channels_last_3d )
437- # Create a linear layer with bfloat16 dtype
456+ input_tensor = torch .randn (N , C_in , * spatial_dims , dtype = dtype , device = "cuda" )
457+ if dim == 3 :
458+ input_tensor = input_tensor .to (memory_format = torch .channels_last_3d )
459+ else :
460+ input_tensor = input_tensor .to (memory_format = torch .channels_last )
461+
438462 model = ToyConvModel (
439463 dim ,
440464 C_in ,
@@ -453,9 +477,9 @@ def test_fp8_conv_skip_quant(
             kernel_preference=kernel_preference,
         )

-        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+        _is_conv = lambda m, fqn: isinstance(m, conv_class)

-        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+        quantize_(quantized_model, config, filter_fn=_is_conv)
         assert not isinstance(quantized_model.conv.weight, Float8Tensor)

         output_original = model(input_tensor)
@@ -832,7 +856,6 @@ def test_index_select(self):
         ],
     )
     def test_unsqueeze_operation(self, granularity, sizes):
-        """Test aten.unsqueeze.default operation on Float8Tensor"""
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
         device = "cuda"
@@ -845,7 +868,7 @@ def test_unsqueeze_operation(self, granularity, sizes):
         original_weight = linear.weight
         original_shape = original_weight.shape

-        # Test unsqueeze operation at dim=0 (only supported dimension)
+        # Test unsqueeze operation at dim=0
         unsqueezed_weight = original_weight.unsqueeze(0)

         # Verify the unsqueezed tensor has correct shape
@@ -887,22 +910,84 @@ def test_unsqueeze_operation(self, granularity, sizes):

         self.assertEqual(unsqueezed_dequant, expected_dequant)

-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
-    def test_unsqueeze_error_cases(self, granularity):
-        """Test error cases for aten.unsqueeze.default operation"""
+    def test_unsqueeze_conv2d_weight(self):
+        granularity = PerTensor()
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
         device = "cuda"
+        N, C_in, C_out, spatial_dims = 4, 16, 64, (32, 32)
+        dim = len(spatial_dims)
+        kernel_size = 3

-        # Create a linear layer and quantize it
-        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
-        quantize_(linear, config)
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device=device)
+        input_tensor = input_tensor.to(memory_format=torch.channels_last)
+        model = ToyConvModel(
+            dim,
+            C_in,
+            C_out,
+            kernel_size,
+            bias=False,
+            padding=0,
+            dtype=dtype,
+            device=device,
+        ).eval()
+
+        quantized_model = copy.deepcopy(model)
+
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity,
+        )
+
+        _is_conv = lambda m, fqn: isinstance(m, torch.nn.Conv2d)

-        weight = linear.weight
+        quantize_(quantized_model, config, filter_fn=_is_conv)

-        # Test that unsqueezing on unsupported dimensions raises an error
-        with self.assertRaisesRegex(AssertionError, "Only dim == 0 is supported"):
-            weight.unsqueeze(1)  # dim=1 should not be supported
+        original_weight = quantized_model.conv.weight
+        original_shape = original_weight.shape
+
+        # Test unsqueeze operation at dim=2
+        unsqueezed_weight = original_weight.unsqueeze(2)
+
+        # Verify the unsqueezed tensor has correct shape
+        original_shape_list = list(original_shape)
+        expected_shape = original_shape_list[:2] + [1] + original_shape_list[2:]
+        scale_shape_list = list(original_weight.scale.shape)
+        expected_scale_shape = scale_shape_list[:2] + [1] + scale_shape_list[2:]
+
+        self.assertEqual(unsqueezed_weight.shape, torch.Size(expected_shape))
+        # Verify qdata and scale shapes
+        expected_qdata_shape = expected_shape
+
+        self.assertEqual(
+            unsqueezed_weight.qdata.shape, torch.Size(expected_qdata_shape)
+        )
+        self.assertEqual(
+            unsqueezed_weight.scale.shape, torch.Size(expected_scale_shape)
+        )
+
+        # Verify block_size is correctly updated
+        expected_block_size = []
+        for i in range(len(expected_shape)):
+            expected_block_size.append(expected_shape[i] // expected_scale_shape[i])
+
+        self.assertEqual(unsqueezed_weight.block_size, expected_block_size)
+
+        # Test that metadata is preserved
+        self.assertEqual(unsqueezed_weight.mm_config, original_weight.mm_config)
+        self.assertEqual(
+            unsqueezed_weight.act_quant_kwargs, original_weight.act_quant_kwargs
+        )
+        self.assertEqual(
+            unsqueezed_weight.kernel_preference, original_weight.kernel_preference
+        )
+        self.assertEqual(unsqueezed_weight.dtype, original_weight.dtype)
+
+        # Test numerical correctness
+        original_dequant = original_weight.dequantize()
+        unsqueezed_dequant = unsqueezed_weight.dequantize()
+        expected_dequant = original_dequant.unsqueeze(2)
+
+        self.assertEqual(unsqueezed_dequant, expected_dequant)

     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     @common_utils.parametrize("slice_dim", [0, 1, 2])
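
Taken together, the new tests exercise an end-to-end 2D path roughly like the standalone sketch below (not part of the diff). A plain Conv2d wrapped in nn.Sequential stands in for ToyConvModel, the import paths are assumed from torchao's public quantization API, and, per the decorator above, fbgemm_gpu_genai and a CUDA device are assumed to be available.

import copy
import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerTensor,
    quantize_,
)

# bf16 baseline: a single 2D conv in channels-last memory format.
model = torch.nn.Sequential(
    torch.nn.Conv2d(16, 64, kernel_size=3, bias=False, padding=0)
)
model = model.to(device="cuda", dtype=torch.bfloat16)
model = model.to(memory_format=torch.channels_last).eval()

x = torch.randn(4, 16, 32, 32, dtype=torch.bfloat16, device="cuda")
x = x.to(memory_format=torch.channels_last)

# Quantize only the Conv2d module, as the new filter_fn in the tests does.
quantized = copy.deepcopy(model)
quantize_(
    quantized,
    Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()),
    filter_fn=lambda m, fqn: isinstance(m, torch.nn.Conv2d),
)

with torch.no_grad():
    baseline = model(x)
    out = quantized(x)
print((baseline - out).abs().max())  # small fp8 quantization error expected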