@@ -86,6 +86,8 @@ def __init__(
         )
         if dim == 3:
             self.conv = self.conv.to(memory_format=torch.channels_last_3d)
+        elif dim == 2:
+            self.conv = self.conv.to(memory_format=torch.channels_last)
 
     def forward(self, x):
         return self.conv(x)
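For reference, the `ToyConvModel` helper exercised below is only partially visible in this diff. Here is a minimal sketch of what it likely looks like, reconstructed from the constructor calls in the tests and the hunk above; the `convs` mapping and any argument names not visible in the diff are assumptions:

```python
import torch


class ToyConvModel(torch.nn.Module):
    """Tiny N-dimensional conv wrapper used by the fp8 conv tests (sketch)."""

    def __init__(self, dim, in_channels, out_channels, kernel_size, bias, padding, dtype, device):
        super().__init__()
        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
        self.conv = convs[dim](
            in_channels, out_channels, kernel_size,
            bias=bias, padding=padding, dtype=dtype, device=device,
        )
        # Keep the conv weight in a channels-last layout so the fp8 conv
        # kernels see the memory format they expect (see the hunk above).
        if dim == 3:
            self.conv = self.conv.to(memory_format=torch.channels_last_3d)
        elif dim == 2:
            self.conv = self.conv.to(memory_format=torch.channels_last)

    def forward(self, x):
        return self.conv(x)
```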
@@ -336,33 +338,43 @@ def _test_fp8_matmul_model(
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("compile", [True, False])
     @common_utils.parametrize("inference_mode", [True, False])
-    # only test for 3D conv for now
-    # Inputs are (N, C_in, C_out, D, H, W)
+    # test for 2D/3D conv
+    # Inputs are (N, C_in, C_out, (D, H, W)) or
+    # (N, C_in, C_out, (H, W))
    @common_utils.parametrize(
         "sizes",
         [
-            (4, 16, 64, 32, 32, 32),
+            (4, 16, 64, (32, 32, 32)),
+            (4, 16, 64, (32, 32)),
         ],
     )
     def test_fp8_conv_variants(
         self,
         dtype: torch.dtype,
         compile: bool,
         inference_mode: bool,
-        kernel_preference: KernelPreference,
         sizes: Tuple,
     ):
+        torch.compiler.reset()
         granularity = PerTensor()
         kernel_preference = KernelPreference.AUTO
-        N, C_in, C_out, D, H, W = sizes
-        dim = 3
+
+        N, C_in, C_out, spatial_dims = sizes
+        dim = len(spatial_dims)
+        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        assert dim in convs, f"Unsupported dim: {dim}"
+        conv_class = convs[dim]
+
         kernel_size = 3
 
         # Note: this is channel last memory format
-        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
-        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
+        if dim == 3:
+            input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        else:
+            assert dim == 2
+            input_tensor = input_tensor.to(memory_format=torch.channels_last)
 
-        # Create a linear layer with bfloat16 dtype
         model = ToyConvModel(
             dim,
             C_in,
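A side note on the memory formats chosen above: `channels_last` (2D) and `channels_last_3d` (3D) keep the logical NCHW/NCDHW shape but reorder the strides so the channel dimension is innermost in memory. A quick illustration with the shapes used in this test (device omitted so it runs anywhere):

```python
import torch

x2d = torch.randn(4, 16, 32, 32).to(memory_format=torch.channels_last)
x3d = torch.randn(4, 16, 32, 32, 32).to(memory_format=torch.channels_last_3d)

# Logical shape is unchanged, but channels now have stride 1 (innermost).
print(x2d.shape)     # torch.Size([4, 16, 32, 32])
print(x2d.stride())  # (16384, 1, 512, 16)
print(x3d.is_contiguous(memory_format=torch.channels_last_3d))  # True
```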
@@ -381,9 +393,9 @@ def test_fp8_conv_variants(
             kernel_preference=kernel_preference,
         )
 
-        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+        _is_conv = lambda m, fqn: isinstance(m, conv_class)
 
-        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+        quantize_(quantized_model, config, filter_fn=_is_conv)
 
         if compile:
             quantized_model = torch.compile(quantized_model, fullgraph=True)
@@ -407,13 +419,16 @@ def test_fp8_conv_variants(
         "Requires fbgemm_gpu_genai to be installed",
     )
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
-    # only test for 3D conv for now
-    # Inputs are (N, C_in, C_out, D, H, W)
+    # test for 2D/3D conv
+    # Inputs are (N, C_in, C_out, (D, H, W)) or
+    # (N, C_in, C_out, (H, W))
     @common_utils.parametrize(
         "sizes",
         [
-            (4, 12, 64, 32, 32, 32),
-            (4, 16, 12, 32, 32, 32),
+            (4, 12, 64, (32, 32, 32)),
+            (4, 16, 12, (32, 32, 32)),
+            (4, 12, 64, (32, 32)),
+            (4, 16, 12, (32, 32)),
         ],
     )
     def test_fp8_conv_skip_quant(
@@ -426,14 +441,23 @@ def test_fp8_conv_skip_quant(
         """
         granularity = PerTensor()
         kernel_preference = KernelPreference.AUTO
-        N, C_in, C_out, D, H, W = sizes
-        dim = 3
+
+        N, C_in, C_out, spatial_dims = sizes
+
+        dim = len(spatial_dims)
+        convs = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
+        assert dim in convs, f"Unsupported dim: {dim}"
+        conv_class = convs[dim]
+
         kernel_size = 3
 
         # Note: this is channel last memory format
-        input_tensor = torch.randn(N, C_in, D, H, W, dtype=dtype, device="cuda")
-        input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
-        # Create a linear layer with bfloat16 dtype
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device="cuda")
+        if dim == 3:
+            input_tensor = input_tensor.to(memory_format=torch.channels_last_3d)
+        else:
+            input_tensor = input_tensor.to(memory_format=torch.channels_last)
+
         model = ToyConvModel(
             dim,
             C_in,
@@ -452,9 +476,9 @@ def test_fp8_conv_skip_quant(
             kernel_preference=kernel_preference,
         )
 
-        _is_conv3d = lambda m, fqn: isinstance(m, torch.nn.Conv3d)
+        _is_conv = lambda m, fqn: isinstance(m, conv_class)
 
-        quantize_(quantized_model, config, filter_fn=_is_conv3d)
+        quantize_(quantized_model, config, filter_fn=_is_conv)
         assert not isinstance(quantized_model.conv.weight, Float8Tensor)
 
         output_original = model(input_tensor)
@@ -793,7 +817,6 @@ def test_index_select(self):
         ],
     )
     def test_unsqueeze_operation(self, granularity, sizes):
-        """Test aten.unsqueeze.default operation on Float8Tensor"""
         config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
         dtype = torch.bfloat16
         device = "cuda"
@@ -806,7 +829,7 @@ def test_unsqueeze_operation(self, granularity, sizes):
         original_weight = linear.weight
         original_shape = original_weight.shape
 
-        # Test unsqueeze operation at dim=0 (only supported dimension)
+        # Test unsqueeze operation at dim=0
        unsqueezed_weight = original_weight.unsqueeze(0)
 
         # Verify the unsqueezed tensor has correct shape
@@ -848,6 +871,85 @@ def test_unsqueeze_operation(self, granularity, sizes):
 
         self.assertEqual(unsqueezed_dequant, expected_dequant)
 
+    def test_unsqueeze_conv2d_weight(self):
+        granularity = PerTensor()
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+        dtype = torch.bfloat16
+        device = "cuda"
+        N, C_in, C_out, spatial_dims = 4, 16, 64, (32, 32)
+        dim = len(spatial_dims)
+        kernel_size = 3
+
+        input_tensor = torch.randn(N, C_in, *spatial_dims, dtype=dtype, device=device)
+        input_tensor = input_tensor.to(memory_format=torch.channels_last)
+        model = ToyConvModel(
+            dim,
+            C_in,
+            C_out,
+            kernel_size,
+            bias=False,
+            padding=0,
+            dtype=dtype,
+            device=device,
+        ).eval()
+
+        quantized_model = copy.deepcopy(model)
+
+        config = Float8DynamicActivationFloat8WeightConfig(
+            granularity=granularity,
+        )
+
+        _is_conv = lambda m, fqn: isinstance(m, torch.nn.Conv2d)
+
+        quantize_(quantized_model, config, filter_fn=_is_conv)
+
+        original_weight = quantized_model.conv.weight
+        original_shape = original_weight.shape
+
+        # Test unsqueeze operation at dim=2
+        unsqueezed_weight = original_weight.unsqueeze(2)
+
+        # Verify the unsqueezed tensor has correct shape
+        original_shape_list = list(original_shape)
+        expected_shape = original_shape_list[:2] + [1] + original_shape_list[2:]
+        scale_shape_list = list(original_weight.scale.shape)
+        expected_scale_shape = scale_shape_list[:2] + [1] + scale_shape_list[2:]
+
+        self.assertEqual(unsqueezed_weight.shape, torch.Size(expected_shape))
+        # Verify qdata and scale shapes
+        expected_qdata_shape = expected_shape
+
+        self.assertEqual(
+            unsqueezed_weight.qdata.shape, torch.Size(expected_qdata_shape)
+        )
+        self.assertEqual(
+            unsqueezed_weight.scale.shape, torch.Size(expected_scale_shape)
+        )
+
+        # Verify block_size is correctly updated
+        expected_block_size = []
+        for i in range(len(expected_shape)):
+            expected_block_size.append(expected_shape[i] // expected_scale_shape[i])
+
+        self.assertEqual(unsqueezed_weight.block_size, expected_block_size)
+
+        # Test that metadata is preserved
+        self.assertEqual(unsqueezed_weight.mm_config, original_weight.mm_config)
+        self.assertEqual(
+            unsqueezed_weight.act_quant_kwargs, original_weight.act_quant_kwargs
+        )
+        self.assertEqual(
+            unsqueezed_weight.kernel_preference, original_weight.kernel_preference
+        )
+        self.assertEqual(unsqueezed_weight.dtype, original_weight.dtype)
+
+        # Test numerical correctness
+        original_dequant = original_weight.dequantize()
+        unsqueezed_dequant = unsqueezed_weight.dequantize()
+        expected_dequant = original_dequant.unsqueeze(2)
+
+        self.assertEqual(unsqueezed_dequant, expected_dequant)
+
     @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
     def test_unsqueeze_error_cases(self, granularity):
         """Test error cases for aten.unsqueeze.default operation"""
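To make the block_size bookkeeping in the new test concrete, here is a small worked example of the same arithmetic. The scale shape below is an assumption for illustration (per-tensor quantization with a one-element-per-dimension scale), not necessarily how Float8Tensor stores it:

```python
# Conv2d weight of shape (C_out, C_in, kH, kW) with an assumed per-tensor
# scale of matching rank: block_size[i] = weight_shape[i] // scale_shape[i].
weight_shape = [64, 16, 3, 3]
scale_shape = [1, 1, 1, 1]
block_size = [w // s for w, s in zip(weight_shape, scale_shape)]
assert block_size == [64, 16, 3, 3]

# After unsqueeze at dim=2 both shapes gain a size-1 dimension, so the new
# dimension contributes a block size of 1 and the rest are unchanged.
weight_shape = weight_shape[:2] + [1] + weight_shape[2:]
scale_shape = scale_shape[:2] + [1] + scale_shape[2:]
block_size = [w // s for w, s in zip(weight_shape, scale_shape)]
assert block_size == [64, 16, 1, 3, 3]
```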