
from pyspark import SparkContext
from pyspark.sql import SparkSession, Column, Row
-from pyspark.sql.functions import UserDefinedFunction
+from pyspark.sql.functions import UserDefinedFunction, udf
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException
from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
@@ -102,7 +102,6 @@ def test_udf_registration_return_type_not_none(self):

    def test_nondeterministic_udf(self):
        # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
-        from pyspark.sql.functions import udf
        import random
        udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
        self.assertEqual(udf_random_col.deterministic, False)
@@ -113,7 +112,6 @@ def test_nondeterministic_udf(self):

    def test_nondeterministic_udf2(self):
        import random
-        from pyspark.sql.functions import udf
        random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
        self.assertEqual(random_udf.deterministic, False)
        random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
@@ -132,7 +130,6 @@ def test_nondeterministic_udf2(self):

    def test_nondeterministic_udf3(self):
        # regression test for SPARK-23233
-        from pyspark.sql.functions import udf
        f = udf(lambda x: x)
        # Here we cache the JVM UDF instance.
        self.spark.range(1).select(f("id"))
@@ -144,7 +141,7 @@ def test_nondeterministic_udf3(self):
        self.assertFalse(deterministic)

    def test_nondeterministic_udf_in_aggregate(self):
-        from pyspark.sql.functions import udf, sum
+        from pyspark.sql.functions import sum
        import random
        udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic()
        df = self.spark.range(10)
@@ -181,7 +178,6 @@ def test_multiple_udfs(self):
        self.assertEqual(tuple(row), (6, 5))

    def test_udf_in_filter_on_top_of_outer_join(self):
-        from pyspark.sql.functions import udf
        left = self.spark.createDataFrame([Row(a=1)])
        right = self.spark.createDataFrame([Row(a=1)])
        df = left.join(right, on='a', how='left_outer')
@@ -190,7 +186,6 @@ def test_udf_in_filter_on_top_of_outer_join(self):

    def test_udf_in_filter_on_top_of_join(self):
        # regression test for SPARK-18589
-        from pyspark.sql.functions import udf
        left = self.spark.createDataFrame([Row(a=1)])
        right = self.spark.createDataFrame([Row(b=1)])
        f = udf(lambda a, b: a == b, BooleanType())
@@ -199,7 +194,6 @@ def test_udf_in_filter_on_top_of_join(self):

    def test_udf_in_join_condition(self):
        # regression test for SPARK-25314
-        from pyspark.sql.functions import udf
        left = self.spark.createDataFrame([Row(a=1)])
        right = self.spark.createDataFrame([Row(b=1)])
        f = udf(lambda a, b: a == b, BooleanType())
@@ -211,7 +205,7 @@ def test_udf_in_join_condition(self):

    def test_udf_in_left_outer_join_condition(self):
        # regression test for SPARK-26147
-        from pyspark.sql.functions import udf, col
+        from pyspark.sql.functions import col
        left = self.spark.createDataFrame([Row(a=1)])
        right = self.spark.createDataFrame([Row(b=1)])
        f = udf(lambda a: str(a), StringType())
@@ -223,7 +217,6 @@ def test_udf_in_left_outer_join_condition(self):

    def test_udf_in_left_semi_join_condition(self):
        # regression test for SPARK-25314
-        from pyspark.sql.functions import udf
        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)])
        f = udf(lambda a, b: a == b, BooleanType())
@@ -236,7 +229,6 @@ def test_udf_in_left_semi_join_condition(self):
    def test_udf_and_common_filter_in_join_condition(self):
        # regression test for SPARK-25314
        # test the complex scenario with both udf and common filter
-        from pyspark.sql.functions import udf
        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
        f = udf(lambda a, b: a == b, BooleanType())
@@ -247,7 +239,6 @@ def test_udf_and_common_filter_in_join_condition(self):
    def test_udf_and_common_filter_in_left_semi_join_condition(self):
        # regression test for SPARK-25314
        # test the complex scenario with both udf and common filter
-        from pyspark.sql.functions import udf
        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
        f = udf(lambda a, b: a == b, BooleanType())
@@ -258,7 +249,6 @@ def test_udf_and_common_filter_in_left_semi_join_condition(self):
    def test_udf_not_supported_in_join_condition(self):
        # regression test for SPARK-25314
        # test python udf is not supported in join type besides left_semi and inner join.
-        from pyspark.sql.functions import udf
        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
        f = udf(lambda a, b: a == b, BooleanType())
@@ -301,7 +291,7 @@ def test_broadcast_in_udf(self):

    def test_udf_with_filter_function(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-        from pyspark.sql.functions import udf, col
+        from pyspark.sql.functions import col
        from pyspark.sql.types import BooleanType

        my_filter = udf(lambda a: a < 2, BooleanType())
@@ -310,7 +300,7 @@ def test_udf_with_filter_function(self):

    def test_udf_with_aggregate_function(self):
        df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-        from pyspark.sql.functions import udf, col, sum
+        from pyspark.sql.functions import col, sum
        from pyspark.sql.types import BooleanType

        my_filter = udf(lambda a: a == 1, BooleanType())
@@ -326,7 +316,7 @@ def test_udf_with_aggregate_function(self):
        self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])

    def test_udf_in_generate(self):
-        from pyspark.sql.functions import udf, explode
+        from pyspark.sql.functions import explode
        df = self.spark.range(5)
        f = udf(lambda x: list(range(x)), ArrayType(LongType()))
        row = df.select(explode(f(*df))).groupBy().sum().first()
@@ -353,7 +343,6 @@ def test_udf_in_generate(self):
        self.assertEqual(res[3][1], 1)

    def test_udf_with_order_by_and_limit(self):
-        from pyspark.sql.functions import udf
        my_copy = udf(lambda x: x, IntegerType())
        df = self.spark.range(10).orderBy("id")
        res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
@@ -394,14 +383,14 @@ def test_non_existed_udaf(self):
            lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))

    def test_udf_with_input_file_name(self):
-        from pyspark.sql.functions import udf, input_file_name
+        from pyspark.sql.functions import input_file_name
        sourceFile = udf(lambda path: path, StringType())
        filePath = "python/test_support/sql/people1.json"
        row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
        self.assertTrue(row[0].find("people1.json") != -1)

    def test_udf_with_input_file_name_for_hadooprdd(self):
-        from pyspark.sql.functions import udf, input_file_name
+        from pyspark.sql.functions import input_file_name

        def filename(path):
            return path
@@ -427,9 +416,6 @@ def test_udf_defers_judf_initialization(self):
        # This is separate of UDFInitializationTests
        # to avoid context initialization
        # when udf is called
-
-        from pyspark.sql.functions import UserDefinedFunction
-
        f = UserDefinedFunction(lambda x: x, StringType())

        self.assertIsNone(
@@ -445,8 +431,6 @@ def test_udf_defers_judf_initialization(self):
        )

    def test_udf_with_string_return_type(self):
-        from pyspark.sql.functions import UserDefinedFunction
-
        add_one = UserDefinedFunction(lambda x: x + 1, "integer")
        make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
        make_array = UserDefinedFunction(
@@ -460,13 +444,11 @@ def test_udf_with_string_return_type(self):
        self.assertTupleEqual(expected, actual)

    def test_udf_shouldnt_accept_noncallable_object(self):
-        from pyspark.sql.functions import UserDefinedFunction
-
        non_callable = None
        self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())

    def test_udf_with_decorator(self):
-        from pyspark.sql.functions import lit, udf
+        from pyspark.sql.functions import lit
        from pyspark.sql.types import IntegerType, DoubleType

        @udf(IntegerType())
@@ -523,7 +505,6 @@ def as_double(x):
        )

    def test_udf_wrapper(self):
-        from pyspark.sql.functions import udf
        from pyspark.sql.types import IntegerType

        def f(x):
@@ -569,7 +550,7 @@ def test_nonparam_udf_with_aggregate(self):
    # SPARK-24721
    @unittest.skipIf(not test_compiled, test_not_compiled_message)
    def test_datasource_with_udf(self):
-        from pyspark.sql.functions import udf, lit, col
+        from pyspark.sql.functions import lit, col

        path = tempfile.mkdtemp()
        shutil.rmtree(path)
@@ -609,8 +590,6 @@ def test_datasource_with_udf(self):

    # SPARK-25591
    def test_same_accumulator_in_udfs(self):
-        from pyspark.sql.functions import udf
-
        data_schema = StructType([StructField("a", IntegerType(), True),
                                  StructField("b", IntegerType(), True)])
        data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
@@ -632,6 +611,15 @@ def second_udf(x):
        data.collect()
        self.assertEqual(test_accum.value, 101)

+    # SPARK-26293
+    def test_udf_in_subquery(self):
+        f = udf(lambda x: x, "long")
+        with self.tempView("v"):
+            self.spark.range(1).filter(f("id") >= 0).createTempView("v")
+            sql = self.spark.sql
+            result = sql("select i from values(0L) as data(i) where i in (select id from v)")
+            self.assertEqual(result.collect(), [Row(i=0)])
+

class UDFInitializationTests(unittest.TestCase):
    def tearDown(self):
@@ -642,8 +630,6 @@ def tearDown(self):
            SparkContext._active_spark_context.stop()

    def test_udf_init_shouldnt_initialize_context(self):
-        from pyspark.sql.functions import UserDefinedFunction
-
        UserDefinedFunction(lambda x: x, StringType())

        self.assertIsNone(