@@ -552,6 +552,70 @@ def test_udf_in_filter_on_top_of_join(self):
         df = left.crossJoin(right).filter(f("a", "b"))
         self.assertEqual(df.collect(), [Row(a=1, b=1)])
 
+    def test_udf_in_join_condition(self):
+        # regression test for SPARK-25314
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1)])
+        right = self.spark.createDataFrame([Row(b=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, f("a", "b"))
+        with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
+            df.collect()
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            self.assertEqual(df.collect(), [Row(a=1, b=1)])
+
+    def test_udf_in_left_semi_join_condition(self):
+        # regression test for SPARK-25314
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, f("a", "b"), "leftsemi")
+        with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
+            df.collect()
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])
+
+    def test_udf_and_common_filter_in_join_condition(self):
+        # regression test for SPARK-25314
+        # test the complex scenario with both a udf and a common filter
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, [f("a", "b"), left.a1 == right.b1])
+        # spark.sql.crossJoin.enabled=true is not needed because the udf is not the only join condition.
+        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+
+    def test_udf_and_common_filter_in_left_semi_join_condition(self):
+        # regression test for SPARK-25314
+        # test the complex scenario with both a udf and a common filter
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, [f("a", "b"), left.a1 == right.b1], "left_semi")
+        # spark.sql.crossJoin.enabled=true is not needed because the udf is not the only join condition.
+        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])
+
+    def test_udf_not_supported_in_join_condition(self):
+        # regression test for SPARK-25314
+        # test that a Python udf is not supported in join types other than inner and left_semi.
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+
+        def runWithJoinType(join_type, type_string):
+            with self.assertRaisesRegexp(
+                    AnalysisException,
+                    'Using PythonUDF.*%s is not supported.' % type_string):
+                left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect()
+        runWithJoinType("full", "FullOuter")
+        runWithJoinType("left", "LeftOuter")
+        runWithJoinType("right", "RightOuter")
+        runWithJoinType("leftanti", "LeftAnti")
+
     def test_udf_without_arguments(self):
         self.spark.catalog.registerFunction("foo", lambda: "bar")
         [row] = self.spark.sql("SELECT foo()").collect()
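
For context outside the test harness, here is a minimal standalone sketch of the behavior these tests pin down: a Python UDF as the sole join condition forces an implicit cartesian product (so `spark.sql.crossJoin.enabled` must be set), while a UDF combined with an ordinary equi-join condition needs no such setting. The session setup and variable names below are illustrative, not part of the patch.

```python
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType

spark = SparkSession.builder.master("local[2]").appName("spark-25314-demo").getOrCreate()

left = spark.createDataFrame([Row(a=1)])
right = spark.createDataFrame([Row(b=1)])
eq = udf(lambda a, b: a == b, BooleanType())

# With the UDF as the only join condition, Spark can only evaluate it after
# joining every pair of rows, so the cartesian product must be allowed explicitly.
spark.conf.set("spark.sql.crossJoin.enabled", "true")
print(left.join(right, eq("a", "b")).collect())  # [Row(a=1, b=1)]

# With an additional equi-join condition, the UDF acts as a post-join filter
# and no cross-join configuration is required.
left2 = spark.createDataFrame([Row(a=1, a1=1)])
right2 = spark.createDataFrame([Row(b=1, b1=1)])
print(left2.join(right2, [eq("a", "b"), left2.a1 == right2.b1]).collect())
```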