Skip to content

Commit 25002e0

Browse files
committed
[SPARK-35278][SQL] Invoke should find the method with correct number of parameters
### What changes were proposed in this pull request? This patch fixes `Invoke` expression when the target object has more than one method with the given method name. ### Why are the changes needed? `Invoke` will find out the method on the target object with given method name. If there are more than one method with the name, currently it is undeterministic which method will be used. We should add the condition of parameter number when finding the method. ### Does this PR introduce _any_ user-facing change? Yes, fixed a bug when using `Invoke` on a object where more than one method with the given method name. ### How was this patch tested? Unit test. Closes apache#32404 from viirya/verify-invoke-param-len. Authored-by: Liang-Chi Hsieh <viirya@gmail.com> Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com> (cherry picked from commit 6ce1b16) Signed-off-by: Liang-Chi Hsieh <viirya@gmail.com>
1 parent c245d84 commit 25002e0

File tree

2 files changed

+55
-5
lines changed

2 files changed

+55
-5
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -319,11 +319,30 @@ case class Invoke(
319319

320320
@transient lazy val method = targetObject.dataType match {
321321
case ObjectType(cls) =>
322-
val m = cls.getMethods.find(_.getName == encodedFunctionName)
323-
if (m.isEmpty) {
324-
sys.error(s"Couldn't find $encodedFunctionName on $cls")
325-
} else {
326-
m
322+
// Looking with function name + argument classes first.
323+
try {
324+
Some(cls.getMethod(encodedFunctionName, argClasses: _*))
325+
} catch {
326+
case _: NoSuchMethodException =>
327+
// For some cases, e.g. arg class is Object, `getMethod` cannot find the method.
328+
// We look at function name + argument length
329+
val m = cls.getMethods.filter { m =>
330+
m.getName == encodedFunctionName && m.getParameterCount == arguments.length
331+
}
332+
if (m.isEmpty) {
333+
sys.error(s"Couldn't find $encodedFunctionName on $cls")
334+
} else if (m.length > 1) {
335+
// More than one matched method signature. Exclude synthetic one, e.g. generic one.
336+
val realMethods = m.filter(!_.isSynthetic)
337+
if (realMethods.length > 1) {
338+
// Ambiguous case, we don't know which method to choose, just fail it.
339+
sys.error(s"Found ${realMethods.length} $encodedFunctionName on $cls")
340+
} else {
341+
Some(realMethods.head)
342+
}
343+
} else {
344+
Some(m.head)
345+
}
327346
}
328347
case _ => None
329348
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,29 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
618618
checkExceptionInExpression[ArithmeticException](
619619
StaticInvoke(mathCls, IntegerType, "addExact", Seq(Literal(Int.MaxValue), Literal(1))), "")
620620
}
621+
622+
test("SPARK-35278: invoke should find method with correct number of parameters") {
623+
val strClsType = ObjectType(classOf[String])
624+
checkExceptionInExpression[StringIndexOutOfBoundsException](
625+
Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(3))), "")
626+
627+
checkObjectExprEvaluation(
628+
Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(0))), "a")
629+
630+
checkExceptionInExpression[StringIndexOutOfBoundsException](
631+
Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(0), Literal(3))), "")
632+
633+
checkObjectExprEvaluation(
634+
Invoke(Literal("a", strClsType), "substring", strClsType, Seq(Literal(0), Literal(1))), "a")
635+
}
636+
637+
test("SPARK-35278: invoke should correctly invoke override method") {
638+
val clsType = ObjectType(classOf[ConcreteClass])
639+
val obj = new ConcreteClass
640+
641+
checkObjectExprEvaluation(
642+
Invoke(Literal(obj, clsType), "testFunc", IntegerType, Seq(Literal(1))), 0)
643+
}
621644
}
622645

623646
class TestBean extends Serializable {
@@ -628,3 +651,11 @@ class TestBean extends Serializable {
628651
def setNonPrimitive(i: AnyRef): Unit =
629652
assert(i != null, "this setter should not be called with null.")
630653
}
654+
655+
abstract class BaseClass[T] {
656+
def testFunc(param: T): T
657+
}
658+
659+
class ConcreteClass extends BaseClass[Int] with Serializable {
660+
override def testFunc(param: Int): Int = param - 1
661+
}

0 commit comments

Comments
 (0)