Skip to content

Commit

Permalink
[SPARK-48845][SQL] GenericUDF catch exceptions from children
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
This PR fixes a semantic issue with GenericUDF that has existed since 3.5.0. The problem arose because DeferredObject currently passes an eagerly evaluated value instead of a deferred computation, which prevented users from catching exceptions thrown by child expressions inside a GenericUDF, resulting in a semantic difference from earlier releases.

Here is an example case we encountered. Originally, the semantics were that udf_exception would throw an exception, while udf_catch_exception could catch the exception and return a null value. However, currently, any exception encountered by udf_exception will cause the program to fail.
```
select udf_catch_exception(udf_exception(col1)) from table
```

### Why are the changes needed?
Before Spark 3.5, the GenericUDF's DeferredObject was lazy: the children were evaluated only inside `function.evaluate(deferredObjects)`, so the UDF could catch their exceptions.
With this change, the children's code runs first; if a child throws, the exception is captured and re-thrown lazily from the GenericUDF's DeferredObject, restoring the ability of the UDF to catch it.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Newly added UT.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#47268 from jackylee-ch/generic_udf_catch_exception_from_child_func.

Lead-authored-by: jackylee-ch <lijunqing@baidu.com>
Co-authored-by: Kent Yao <yao@apache.org>
Signed-off-by: Kent Yao <yao@apache.org>
  • Loading branch information
jackylee-ch and yaooqinn committed Jul 12, 2024
1 parent d747853 commit 236d957
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,11 @@ class HiveGenericUDFEvaluator(
override def returnType: DataType = inspectorToDataType(returnInspector)

def setArg(index: Int, arg: Any): Unit =
deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(arg)
deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(() => arg)

def setException(index: Int, exp: Throwable): Unit = {
deferredObjects(index).asInstanceOf[DeferredObjectAdapter].set(() => throw exp)
}

override def doEvaluate(): Any = unwrapper(function.evaluate(deferredObjects))
}
Expand All @@ -139,10 +143,10 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp
extends DeferredObject with HiveInspectors {

private val wrapper = wrapperFor(oi, dataType)
private var func: Any = _
def set(func: Any): Unit = {
private var func: () => Any = _
def set(func: () => Any): Unit = {
this.func = func
}
override def prepare(i: Int): Unit = {}
override def get(): AnyRef = wrapper(func).asInstanceOf[AnyRef]
override def get(): AnyRef = wrapper(func()).asInstanceOf[AnyRef]
}
22 changes: 16 additions & 6 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,13 @@ private[hive] case class HiveGenericUDF(

override def eval(input: InternalRow): Any = {
children.zipWithIndex.foreach {
case (child, idx) => evaluator.setArg(idx, child.eval(input))
case (child, idx) =>
try {
evaluator.setArg(idx, child.eval(input))
} catch {
case t: Throwable =>
evaluator.setException(idx, t)
}
}
evaluator.evaluate()
}
Expand All @@ -157,10 +163,15 @@ private[hive] case class HiveGenericUDF(
val setValues = evals.zipWithIndex.map {
case (eval, i) =>
s"""
|if (${eval.isNull}) {
| $refEvaluator.setArg($i, null);
|} else {
| $refEvaluator.setArg($i, ${eval.value});
|try {
| ${eval.code}
| if (${eval.isNull}) {
| $refEvaluator.setArg($i, null);
| } else {
| $refEvaluator.setArg($i, ${eval.value});
| }
|} catch (Throwable t) {
| $refEvaluator.setException($i, t);
|}
|""".stripMargin
}
Expand All @@ -169,7 +180,6 @@ private[hive] case class HiveGenericUDF(
val resultTerm = ctx.freshName("result")
ev.copy(code =
code"""
|${evals.map(_.code).mkString("\n")}
|${setValues.mkString("\n")}
|$resultType $resultTerm = ($resultType) $refEvaluator.evaluate();
|boolean ${ev.isNull} = $resultTerm == null;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive.execution;

import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

/**
 * Test {@link GenericUDF} that evaluates its single child expression and maps
 * any evaluation failure to a SQL NULL instead of propagating the exception.
 * Used to verify that exceptions thrown by child expressions reach the UDF.
 */
public class UDFCatchException extends GenericUDF {

  @Override
  public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
    // Contract: exactly one child expression.
    if (arguments.length != 1) {
      throw new UDFArgumentException("Exactly one argument is expected.");
    }
    // The result is always exposed as a plain Java String.
    return PrimitiveObjectInspectorFactory.javaStringObjectInspector;
  }

  @Override
  public Object evaluate(GenericUDF.DeferredObject[] arguments) {
    if (arguments == null) {
      return null;
    }
    try {
      // DeferredObject.get() may throw when the child expression failed.
      return arguments[0].get();
    } catch (Exception e) {
      // Deliberately swallow the child's failure and surface NULL instead.
      return null;
    }
  }

  @Override
  public String getDisplayString(String[] children) {
    // Display string is irrelevant for this test fixture.
    return null;
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.hive.execution;

import org.apache.hadoop.hive.ql.exec.UDF;

/**
 * Test UDF that parses its input as a decimal integer and echoes it back as a
 * string. Non-numeric input (e.g. "9-1") makes the underlying parse throw a
 * NumberFormatException, which is exactly what the catching UDF tests rely on.
 */
public class UDFThrowException extends UDF {
  public String evaluate(String data) {
    // Integer.parseInt throws NumberFormatException for invalid input.
    return String.valueOf(Integer.parseInt(data));
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import org.apache.hadoop.io.{LongWritable, Writable}

import org.apache.spark.{SparkException, SparkFiles, TestUtils}
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.functions.{call_function, max}
Expand Down Expand Up @@ -801,6 +802,28 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
}
}
}

// SPARK-48845 regression test: a GenericUDF must be able to catch exceptions
// thrown while evaluating its child expressions. Previously the child was
// evaluated eagerly, so its exception aborted the whole query before the
// outer UDF could intercept it.
test("SPARK-48845: GenericUDF catch exceptions from child UDFs") {
withTable("test_catch_exception") {
withUserDefinedFunction("udf_throw_exception" -> true, "udf_catch_exception" -> true) {
// "9" parses cleanly; "9-1" makes udf_throw_exception fail at evaluation time.
Seq("9", "9-1").toDF("a").write.saveAsTable("test_catch_exception")
sql("CREATE TEMPORARY FUNCTION udf_throw_exception AS " +
s"'${classOf[UDFThrowException].getName}'")
sql("CREATE TEMPORARY FUNCTION udf_catch_exception AS " +
s"'${classOf[UDFCatchException].getName}'")
// Exercise both the codegen and the interpreted evaluation paths.
Seq(
CodegenObjectFactoryMode.FALLBACK.toString,
CodegenObjectFactoryMode.NO_CODEGEN.toString
).foreach { codegenMode =>
withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) {
val df = sql(
"SELECT udf_catch_exception(udf_throw_exception(a)) FROM test_catch_exception")
// The failing row must surface as NULL, not as a query failure.
checkAnswer(df, Seq(Row("9"), Row(null)))
}
}
}
}
}
}

class TestPair(x: Int, y: Int) extends Writable with Serializable {
Expand Down

0 comments on commit 236d957

Please sign in to comment.