Skip to content

Commit e576c1e

Browse files
HyukjinKwongatorsmile
authored andcommitted
[SPARK-9435][SQL] Reuse function in Java UDF to correctly support expressions that require equality comparison between ScalaUDF
## What changes were proposed in this pull request? Currently, running the codes in Java ```java spark.udf().register("inc", new UDF1<Long, Long>() { Override public Long call(Long i) { return i + 1; } }, DataTypes.LongType); spark.range(10).toDF("x").createOrReplaceTempView("tmp"); Row result = spark.sql("SELECT inc(x) FROM tmp GROUP BY inc(x)").head(); Assert.assertEquals(7, result.getLong(0)); ``` fails as below: ``` org.apache.spark.sql.AnalysisException: expression 'tmp.`x`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.;; Aggregate [UDF(x#19L)], [UDF(x#19L) AS UDF(x)#23L] +- SubqueryAlias tmp, `tmp` +- Project [id#16L AS x#19L] +- Range (0, 10, step=1, splits=Some(8)) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$class.failAnalysis(CheckAnalysis.scala:40) at org.apache.spark.sql.catalyst.analysis.Analyzer.failAnalysis(Analyzer.scala:57) ``` The root cause is because we were creating the function every time when it needs to build as below: ```scala scala> def inc(i: Int) = i + 1 inc: (i: Int)Int scala> (inc(_: Int)).hashCode res15: Int = 1231799381 scala> (inc(_: Int)).hashCode res16: Int = 2109839984 scala> (inc(_: Int)) == (inc(_: Int)) res17: Boolean = false ``` This seems leading to the comparison failure between `ScalaUDF`s created from Java UDF API, for example, in `Expression.semanticEquals`. In case of Scala one, it seems already fine. Both can be tested easily as below if any reviewer is more comfortable with Scala: ```scala val df = Seq((1, 10), (2, 11), (3, 12)).toDF("x", "y") val javaUDF = new UDF1[Int, Int] { override def call(i: Int): Int = i + 1 } // spark.udf.register("inc", javaUDF, IntegerType) // Uncomment this for Java API // spark.udf.register("inc", (i: Int) => i + 1) // Uncomment this for Scala API df.createOrReplaceTempView("tmp") spark.sql("SELECT inc(y) FROM tmp GROUP BY inc(y)").show() ``` ## How was this patch tested? Unit test in `JavaUDFSuite.java` and `./dev/lint-java`. Author: hyukjinkwon <gurwls223@gmail.com> Closes #16553 from HyukjinKwon/SPARK-9435.
1 parent 3bdf3ee commit e576c1e

File tree

2 files changed

+68
-23
lines changed

2 files changed

+68
-23
lines changed

sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,10 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
109109
| * @since 1.3.0
110110
| */
111111
|def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType): Unit = {
112+
| val func = f$anyCast.call($anyParams)
112113
| functionRegistry.registerFunction(
113114
| name,
114-
| (e: Seq[Expression]) => ScalaUDF(f$anyCast.call($anyParams), returnType, e))
115+
| (e: Seq[Expression]) => ScalaUDF(func, returnType, e))
115116
|}""".stripMargin)
116117
}
117118
*/
@@ -488,219 +489,241 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
488489
* @since 1.3.0
489490
*/
490491
def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = {
492+
val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any)
491493
functionRegistry.registerFunction(
492494
name,
493-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), returnType, e))
495+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
494496
}
495497

496498
/**
497499
* Register a user-defined function with 2 arguments.
498500
* @since 1.3.0
499501
*/
500502
def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = {
503+
val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
501504
functionRegistry.registerFunction(
502505
name,
503-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), returnType, e))
506+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
504507
}
505508

506509
/**
507510
* Register a user-defined function with 3 arguments.
508511
* @since 1.3.0
509512
*/
510513
def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = {
514+
val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any)
511515
functionRegistry.registerFunction(
512516
name,
513-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), returnType, e))
517+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
514518
}
515519

516520
/**
517521
* Register a user-defined function with 4 arguments.
518522
* @since 1.3.0
519523
*/
520524
def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = {
525+
val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any)
521526
functionRegistry.registerFunction(
522527
name,
523-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), returnType, e))
528+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
524529
}
525530

526531
/**
527532
* Register a user-defined function with 5 arguments.
528533
* @since 1.3.0
529534
*/
530535
def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = {
536+
val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any)
531537
functionRegistry.registerFunction(
532538
name,
533-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
539+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
534540
}
535541

536542
/**
537543
* Register a user-defined function with 6 arguments.
538544
* @since 1.3.0
539545
*/
540546
def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = {
547+
val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
541548
functionRegistry.registerFunction(
542549
name,
543-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
550+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
544551
}
545552

546553
/**
547554
* Register a user-defined function with 7 arguments.
548555
* @since 1.3.0
549556
*/
550557
def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = {
558+
val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
551559
functionRegistry.registerFunction(
552560
name,
553-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
561+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
554562
}
555563

556564
/**
557565
* Register a user-defined function with 8 arguments.
558566
* @since 1.3.0
559567
*/
560568
def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
569+
val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
561570
functionRegistry.registerFunction(
562571
name,
563-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
572+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
564573
}
565574

566575
/**
567576
* Register a user-defined function with 9 arguments.
568577
* @since 1.3.0
569578
*/
570579
def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
580+
val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
571581
functionRegistry.registerFunction(
572582
name,
573-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
583+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
574584
}
575585

576586
/**
577587
* Register a user-defined function with 10 arguments.
578588
* @since 1.3.0
579589
*/
580590
def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
591+
val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
581592
functionRegistry.registerFunction(
582593
name,
583-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
594+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
584595
}
585596

586597
/**
587598
* Register a user-defined function with 11 arguments.
588599
* @since 1.3.0
589600
*/
590601
def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
602+
val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
591603
functionRegistry.registerFunction(
592604
name,
593-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
605+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
594606
}
595607

596608
/**
597609
* Register a user-defined function with 12 arguments.
598610
* @since 1.3.0
599611
*/
600612
def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
613+
val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
601614
functionRegistry.registerFunction(
602615
name,
603-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
616+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
604617
}
605618

606619
/**
607620
* Register a user-defined function with 13 arguments.
608621
* @since 1.3.0
609622
*/
610623
def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
624+
val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
611625
functionRegistry.registerFunction(
612626
name,
613-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
627+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
614628
}
615629

616630
/**
617631
* Register a user-defined function with 14 arguments.
618632
* @since 1.3.0
619633
*/
620634
def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
635+
val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
621636
functionRegistry.registerFunction(
622637
name,
623-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
638+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
624639
}
625640

626641
/**
627642
* Register a user-defined function with 15 arguments.
628643
* @since 1.3.0
629644
*/
630645
def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
646+
val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
631647
functionRegistry.registerFunction(
632648
name,
633-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
649+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
634650
}
635651

636652
/**
637653
* Register a user-defined function with 16 arguments.
638654
* @since 1.3.0
639655
*/
640656
def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
657+
val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
641658
functionRegistry.registerFunction(
642659
name,
643-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
660+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
644661
}
645662

646663
/**
647664
* Register a user-defined function with 17 arguments.
648665
* @since 1.3.0
649666
*/
650667
def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
668+
val func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
651669
functionRegistry.registerFunction(
652670
name,
653-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
671+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
654672
}
655673

656674
/**
657675
* Register a user-defined function with 18 arguments.
658676
* @since 1.3.0
659677
*/
660678
def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
679+
val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
661680
functionRegistry.registerFunction(
662681
name,
663-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
682+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
664683
}
665684

666685
/**
667686
* Register a user-defined function with 19 arguments.
668687
* @since 1.3.0
669688
*/
670689
def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
690+
val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
671691
functionRegistry.registerFunction(
672692
name,
673-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
693+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
674694
}
675695

676696
/**
677697
* Register a user-defined function with 20 arguments.
678698
* @since 1.3.0
679699
*/
680700
def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
701+
val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
681702
functionRegistry.registerFunction(
682703
name,
683-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
704+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
684705
}
685706

686707
/**
687708
* Register a user-defined function with 21 arguments.
688709
* @since 1.3.0
689710
*/
690711
def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
712+
val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
691713
functionRegistry.registerFunction(
692714
name,
693-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
715+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
694716
}
695717

696718
/**
697719
* Register a user-defined function with 22 arguments.
698720
* @since 1.3.0
699721
*/
700722
def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = {
723+
val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
701724
functionRegistry.registerFunction(
702725
name,
703-
(e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e))
726+
(e: Seq[Expression]) => ScalaUDF(func, returnType, e))
704727
}
705728

706729
// scalastyle:on line.size.limit

sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package test.org.apache.spark.sql;
1919

2020
import java.io.Serializable;
21+
import java.util.List;
2122

2223
import org.junit.After;
2324
import org.junit.Assert;
@@ -108,4 +109,25 @@ public void udf3Test() {
108109
result = spark.sql("SELECT stringLengthTest('test', 'test2')").head();
109110
Assert.assertEquals(9, result.getInt(0));
110111
}
112+
113+
@SuppressWarnings("unchecked")
114+
@Test
115+
public void udf4Test() {
116+
spark.udf().register("inc", new UDF1<Long, Long>() {
117+
@Override
118+
public Long call(Long i) {
119+
return i + 1;
120+
}
121+
}, DataTypes.LongType);
122+
123+
spark.range(10).toDF("x").createOrReplaceTempView("tmp");
124+
// This tests when Java UDFs are required to be the semantically same (See SPARK-9435).
125+
List<Row> results = spark.sql("SELECT inc(x) FROM tmp GROUP BY inc(x)").collectAsList();
126+
Assert.assertEquals(10, results.size());
127+
long sum = 0;
128+
for (Row result : results) {
129+
sum += result.getLong(0);
130+
}
131+
Assert.assertEquals(55, sum);
132+
}
111133
}

0 commit comments

Comments
 (0)