@@ -596,6 +596,70 @@ def test_udf_in_filter_on_top_of_join(self):
         df = left.crossJoin(right).filter(f("a", "b"))
         self.assertEqual(df.collect(), [Row(a=1, b=1)])
 
+    def test_udf_in_join_condition(self):
+        # regression test for SPARK-25314
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1)])
+        right = self.spark.createDataFrame([Row(b=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, f("a", "b"))
+        with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
+            df.collect()
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            self.assertEqual(df.collect(), [Row(a=1, b=1)])
+
+    def test_udf_in_left_semi_join_condition(self):
+        # regression test for SPARK-25314
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, f("a", "b"), "leftsemi")
+        with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
+            df.collect()
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])
+
+    def test_udf_and_common_filter_in_join_condition(self):
+        # regression test for SPARK-25314
+        # test the complex scenario with both a udf and a common filter
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, [f("a", "b"), left.a1 == right.b1])
+        # spark.sql.crossJoin.enabled=true is not needed because the udf is not the only join condition.
+        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+
+    def test_udf_and_common_filter_in_left_semi_join_condition(self):
+        # regression test for SPARK-25314
+        # test the complex scenario with both a udf and a common filter
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, [f("a", "b"), left.a1 == right.b1], "left_semi")
+        # spark.sql.crossJoin.enabled=true is not needed because the udf is not the only join condition.
+        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])
+
+    def test_udf_not_supported_in_join_condition(self):
+        # regression test for SPARK-25314
+        # test that a python udf is not supported in join types other than left_semi and inner join.
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+
+        def runWithJoinType(join_type, type_string):
+            with self.assertRaisesRegexp(
+                    AnalysisException,
+                    'Using PythonUDF.*%s is not supported.' % type_string):
+                left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect()
+        runWithJoinType("full", "FullOuter")
+        runWithJoinType("left", "LeftOuter")
+        runWithJoinType("right", "RightOuter")
+        runWithJoinType("leftanti", "LeftAnti")
+
     def test_udf_without_arguments(self):
         self.spark.catalog.registerFunction("foo", lambda: "bar")
         [row] = self.spark.sql("SELECT foo()").collect()