@@ -633,6 +633,92 @@ def test_conv2d_binary(self):
    def test_conv3d_binary(self):
        self._test_conv_binary_base(dim=5)

+    def _test_conv_binary_broadcast_shapes_base(self, dim=4):
+        assert dim == 4 or dim == 5
+
+        class M(torch.nn.Module):
+            def __init__(
+                self,
+                binary_fn,
+                has_relu,
+                **kwargs,
+            ):
+                super().__init__()
+                if dim == 4:
+                    self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1)
+                else:
+                    self.conv = torch.nn.Conv3d(3, 16, kernel_size=3, stride=1)
+                self.binary_fn = binary_fn
+                self.has_relu = has_relu
+
+            def forward(self, x, x2):
+                x1 = self.conv(x)
+                if self.has_relu:
+                    return self.binary_fn(x1, x2).relu()
+                else:
+                    return self.binary_fn(x1, x2)
+
+        dtypes = [
+            torch.float,
+        ]
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            dtypes.append(torch.bfloat16)
+        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
+            dtypes.append(torch.float16)
+        cl_format = torch.channels_last if dim == 4 else torch.channels_last_3d
+        test_memory_format = [torch.contiguous_format, cl_format]
+        options = itertools.product(
+            binary_list,
+            [True, False],
+            test_memory_format,
+            dtypes,
+        )
+
+        for (
+            binary_fn,
+            has_relu,
+            memory_format,
+            dtype,
+        ) in options:
+            metrics.reset()
+            if dim == 4:
+                x_shape = (1, 3, 56, 56)
+                other_shape = (1, 16, 1, 1)
+            else:
+                x_shape = (1, 3, 20, 56, 56)
+                other_shape = (1, 16, 1, 1, 1)
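+            # With dim == 4, the conv output is (1, 16, 54, 54) for the
+            # (1, 3, 56, 56) input (kernel_size=3, stride=1, no padding),
+            # so `other` of shape (1, 16, 1, 1) broadcasts over the spatial
+            # dims of the binary op; the dim == 5 case is analogous.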
+            mod = M(binary_fn, has_relu).eval()
+            x = (
+                torch.randn(x_shape, dtype=torch.float32, requires_grad=True)
+                .add(1)
+                .to(memory_format=memory_format)
+            )
+            other = (
+                torch.randn(other_shape, dtype=torch.float32, requires_grad=True)
+                .add(1)
+                .to(memory_format=memory_format)
+                .to(dtype)
+            )
+            match_count = binary_list[binary_fn][0] + 1
+            match_nodes = binary_list[binary_fn][1]
+            if has_relu:
+                match_nodes += 1
+            self._test_common(
+                mod, (x, other), match_count, match_nodes + 1, check_autocast=dtype
+            )
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_conv2d_binary_broadcast_shapes_cpu(self):
+        self._test_conv_binary_broadcast_shapes_base(dim=4)
+
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    @skipIfRocm
+    def test_conv3d_binary_broadcast_shapes_cpu(self):
+        self._test_conv_binary_broadcast_shapes_base(dim=5)
+
    def test_linear_binary(self):
        class M(torch.nn.Module):
            def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
@@ -683,6 +769,55 @@ def forward(self, x, y):
            )
            self.assertEqual(metrics.generated_kernel_count, 1)

+    def test_linear_binary_broadcast_shapes_cpu(self):
+        class M(torch.nn.Module):
+            def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
+                super().__init__()
+                self.linear = torch.nn.Linear(
+                    in_channels, out_channels, bias=bias, **kwargs
+                )
+                self.binary_fn = binary_fn
+
+            def forward(self, x, y):
+                x = self.linear(x)
+                x = self.binary_fn(x, y.clone())
+                return x
+
+        dtypes = []
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            dtypes.append(torch.bfloat16)
+        if torch.ops.mkldnn._is_mkldnn_fp16_supported():
+            dtypes.append(torch.float16)
+        options = itertools.product(
+            binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes
+        )
+        out_feature = 30
+
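+        # input_shape [2, 3, 10] exercises the 3D-input path, which adds the
+        # extra view + linear + view patterns counted below; [2, 10] is the
+        # plain 2D case.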
+        for binary_fn, input_shape, bias, dtype in options:
+            metrics.reset()
+            # addmm(mm) + (linear + add)
+            match_count = 2
+            match_nodes = 3
+            if len(input_shape) == 3:
+                is_inplace = binary_list[binary_fn][2]
+                # view + linear + view (joint_graph + freeze pass)
+                match_count = match_count + 5 if is_inplace else match_count + 3
+                match_nodes = match_nodes + 8 if is_inplace else match_nodes + 5
+            mod = M(binary_fn, input_shape[-1], out_feature, bias).eval()
+            v = torch.randn(input_shape)
+            other = torch.randn(input_shape[:-1] + [1]).to(dtype)
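+            # `other` keeps the leading dims of the input but ends in a
+            # size-1 dim, so it broadcasts against the linear output of
+            # width out_feature (30) in the fused binary op.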
+            self._test_common(
+                mod,
+                (
+                    v,
+                    other,
+                ),
+                match_count,
+                match_nodes,
+                check_autocast=dtype,
+            )
+            self.assertEqual(metrics.generated_kernel_count, 1)
+
    def test_multi_linear_share_same_input(self):
        # llama pattern.
        class M(torch.nn.Module):