[SPARK-48834][SQL] Disable variant input/output to python scalar UDFs, UDTFs, UDAFs during query compilation

### What changes were proposed in this pull request?

Throws an exception during query compilation if a variant appears anywhere in the input or output type of a Python UDF, UDAF, or UDTF. A rough sketch of the containment check follows.
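
The check treats a type as unsupported when a variant appears anywhere inside it: at the top level, in a struct field, in an array element, or in a map value. Below is a rough Python analogue of the Scala `typeContainsVariant` helper added by this PR, as an illustrative sketch only (the `type_contains_variant` name is hypothetical and not part of the change):

```python
from pyspark.sql.types import ArrayType, DataType, MapType, StructType, VariantType


def type_contains_variant(dt: DataType) -> bool:
    """Return True if ``dt`` is, or nests, a variant type."""
    if isinstance(dt, VariantType):
        return True
    if isinstance(dt, StructType):
        return any(type_contains_variant(f.dataType) for f in dt.fields)
    if isinstance(dt, ArrayType):
        return type_contains_variant(dt.elementType)
    if isinstance(dt, MapType):
        # Variants cannot be map keys, so only the value type needs checking.
        return type_contains_variant(dt.valueType)
    return False
```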

### Why are the changes needed?

Currently, variant input/output types to scalar UDFs either fail during execution or surface a `net.razorvine.pickle.objects.ClassDictConstructor` object to the user's Python code. For a better user experience, such failures should occur during query compilation, and returning `ClassDictConstructor` to user code should be blocked because we eventually want to return proper `VariantVal`s instead.

### Does this PR introduce _any_ user-facing change?

Yes. Attempting to use variants in Python UDFs now throws an exception during query compilation rather than returning a `ClassDictConstructor` as before. We want to make this change now because we eventually want to return `VariantVal`s to user code and do not want users to rely on the current behavior.
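
For example, the following minimal sketch (assuming an active `SparkSession` named `spark` on a build with variant support) now fails at analysis time with `DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE` instead of failing during execution or leaking pickle internals:

```python
from pyspark.errors import AnalysisException
from pyspark.sql.functions import col, udf
from pyspark.sql.types import StringType

df = spark.range(10).selectExpr("parse_json(cast(id as string)) v")
to_str = udf(lambda v: str(v), StringType())  # variant column used as UDF input

try:
    df.select(to_str(col("v"))).collect()
except AnalysisException as e:
    # Expected after this change: DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE
    print(e.getErrorClass())
```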

### How was this patch tested?

Added unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#47253 from richardc-db/variant_scalar_udfs.

Authored-by: Richard Chen <r.chen@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
richardc-db authored and HyukjinKwon committed Jul 15, 2024
1 parent 206cc1a commit effd4d8
Showing 15 changed files with 582 additions and 18 deletions.
10 changes: 10 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -954,6 +954,16 @@
"The input of <functionName> can't be <dataType> type data."
]
},
"UNSUPPORTED_UDF_INPUT_TYPE" : {
"message" : [
"UDFs do not support '<dataType>' as an input data type."
]
},
"UNSUPPORTED_UDF_OUTPUT_TYPE" : {
"message" : [
"UDFs do not support '<dataType>' as an output data type."
]
},
"VALUE_OUT_OF_RANGE" : {
"message" : [
"The <exprName> must be between <valueRange> (current value = <currentValue>)."
@@ -359,7 +359,7 @@ object ArrowDeserializers {
}
}

case (CalendarIntervalEncoder | _: UDTEncoder[_], _) =>
case (CalendarIntervalEncoder | VariantEncoder | _: UDTEncoder[_], _) =>
throw ExecutionErrors.unsupportedDataTypeError(encoder.dataType)

case _ =>
@@ -442,7 +442,7 @@ object ArrowSerializer {
o => getter.invoke(o)
}

case (CalendarIntervalEncoder | _: UDTEncoder[_], _) =>
case (CalendarIntervalEncoder | VariantEncoder | _: UDTEncoder[_], _) =>
throw ExecutionErrors.unsupportedDataTypeError(encoder.dataType)

case _ =>
45 changes: 45 additions & 0 deletions python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -47,6 +47,8 @@
DateType,
BinaryType,
YearMonthIntervalType,
VariantType,
VariantVal,
)
from pyspark.errors import AnalysisException, PythonException
from pyspark.testing.sqlutils import (
@@ -748,6 +750,49 @@ def check_vectorized_udf_return_scalar(self):
with self.assertRaisesRegex(Exception, "Return.*type.*Series"):
df.select(f(col("id"))).collect()

def test_udf_with_variant_input(self):
df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
from pyspark.sql.functions import col

scalar_f = pandas_udf(lambda u: str(u), StringType())
iter_f = pandas_udf(
lambda it: map(lambda u: str(u), it), StringType(), PandasUDFType.SCALAR_ITER
)

for f in [scalar_f, iter_f]:
with self.assertRaises(AnalysisException) as ae:
df.select(f(col("v"))).collect()

self.check_error(
exception=ae.exception,
error_class="DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE",
message_parameters={
"sqlExpr": '"<lambda>(v)"',
"dataType": "VARIANT",
},
)

def test_udf_with_variant_output(self):
# Corresponds to a JSON string of {"a": "b"}.
returned_variant = VariantVal(bytes([2, 1, 0, 0, 2, 5, 98]), bytes([1, 1, 0, 1, 97]))
scalar_f = pandas_udf(lambda x: returned_variant, VariantType())
iter_f = pandas_udf(
lambda it: map(lambda x: returned_variant, it), VariantType(), PandasUDFType.SCALAR_ITER
)

for f in [scalar_f, iter_f]:
with self.assertRaises(AnalysisException) as ae:
self.spark.range(0, 10).select(f()).collect()

self.check_error(
exception=ae.exception,
error_class="DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE",
message_parameters={
"sqlExpr": '"<lambda>()"',
"dataType": "VARIANT",
},
)

def test_vectorized_udf_decorator(self):
df = self.spark.range(10)

39 changes: 39 additions & 0 deletions python/pyspark/sql/tests/test_types.py
@@ -29,6 +29,7 @@
from pyspark.sql import functions as F
from pyspark.errors import (
AnalysisException,
ParseException,
PySparkTypeError,
PySparkValueError,
PySparkRuntimeError,
@@ -2216,6 +2217,44 @@ def test_from_ddl(self):
StructType([StructField("a", IntegerType()), StructField("v", VariantType())]),
)

# Ensures that changing the implementation of `DataType.fromDDL` in PR #47253 does not change
# `fromDDL`'s behavior.
def test_spark48834_from_ddl_matches_udf_schema_string(self):
from pyspark.sql.functions import udf

def schema_from_udf(ddl):
schema = (
self.spark.active().range(0).select(udf(lambda x: x, returnType=ddl)("id")).schema
)
assert len(schema) == 1
return schema[0].dataType

tests = [
("a:int, b:string", True),
(
"a struct<>, b map<int, binary>, "
+ "c array<array<map<struct<a: int, b: int>, binary>>>",
True,
),
("struct<>", True),
("struct<a: string, b: array<long>>", True),
("", True),
("<a: int, b: variant>", False),
("randomstring", False),
("struct", False),
]
for test, is_valid_input in tests:
if is_valid_input:
self.assertEqual(DataType.fromDDL(test), schema_from_udf(test))
else:
with self.assertRaises(ParseException) as from_ddl_pe:
DataType.fromDDL(test)
with self.assertRaises(ParseException) as udf_pe:
schema_from_udf(test)
self.assertEqual(
from_ddl_pe.exception.getErrorClass(), udf_pe.exception.getErrorClass()
)

def test_collated_string(self):
dfs = [
self.spark.sql("SELECT 'abc' collate UTF8_LCASE"),
68 changes: 67 additions & 1 deletion python/pyspark/sql/tests/test_udf.py
@@ -35,10 +35,13 @@
DoubleType,
LongType,
ArrayType,
MapType,
StructType,
StructField,
TimestampNTZType,
DayTimeIntervalType,
VariantType,
VariantVal,
)
from pyspark.errors import AnalysisException, PythonException, PySparkTypeError
from pyspark.testing.sqlutils import (
@@ -324,12 +327,75 @@ def test_broadcast_in_udf(self):

def test_udf_with_filter_function(self):
df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
from pyspark.sql.functions import col

my_filter = udf(lambda a: a < 2, BooleanType())
sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
self.assertEqual(sel.collect(), [Row(key=1, value="1")])

def test_udf_with_variant_input(self):
df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")

u = udf(lambda u: str(u), StringType())
with self.assertRaises(AnalysisException) as ae:
df.select(u(col("v"))).collect()

self.check_error(
exception=ae.exception,
error_class="DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE",
message_parameters={"sqlExpr": '"<lambda>(v)"', "dataType": "VARIANT"},
)

def test_udf_with_complex_variant_input(self):
df = self.spark.range(0, 10).selectExpr(
"named_struct('v', parse_json(cast(id as string))) struct_of_v"
)

u = udf(lambda u: str(u), StringType())

with self.assertRaises(AnalysisException) as ae:
df.select(u(col("struct_of_v"))).collect()

self.check_error(
exception=ae.exception,
error_class="DATATYPE_MISMATCH.UNSUPPORTED_UDF_INPUT_TYPE",
message_parameters={
"sqlExpr": '"<lambda>(struct_of_v)"',
"dataType": "STRUCT<v: VARIANT NOT NULL>",
},
)

def test_udf_with_variant_output(self):
# The variant value returned corresponds to a JSON string of {"a": "b"}.
u = udf(
lambda: VariantVal(bytes([2, 1, 0, 0, 2, 5, 98]), bytes([1, 1, 0, 1, 97])),
VariantType(),
)

with self.assertRaises(AnalysisException) as ae:
self.spark.range(0, 10).select(u()).collect()

self.check_error(
exception=ae.exception,
error_class="DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE",
message_parameters={"sqlExpr": '"<lambda>()"', "dataType": "VARIANT"},
)

def test_udf_with_complex_variant_output(self):
# The variant value returned corresponds to a JSON string of {"a": "b"}.
u = udf(
lambda: {"v", VariantVal(bytes([2, 1, 0, 0, 2, 5, 98]), bytes([1, 1, 0, 1, 97]))},
MapType(StringType(), VariantType()),
)

with self.assertRaises(AnalysisException) as ae:
self.spark.range(0, 10).select(u()).collect()

self.check_error(
exception=ae.exception,
error_class="DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE",
message_parameters={"sqlExpr": '"<lambda>()"', "dataType": "MAP<STRING, VARIANT>"},
)

def test_udf_with_aggregate_function(self):
df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
from pyspark.sql.functions import col, sum
17 changes: 7 additions & 10 deletions python/pyspark/sql/types.py
@@ -194,16 +194,7 @@ def fromDDL(cls, ddl: str) -> "DataType":
>>> DataType.fromDDL("b: string, a: int")
StructType([StructField('b', StringType(), True), StructField('a', IntegerType(), True)])
"""
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

# Intentionally uses SparkSession so one implementation can be shared with/without
# Spark Connect.
schema = (
SparkSession.active().range(0).select(udf(lambda x: x, returnType=ddl)("id")).schema
)
assert len(schema) == 1
return schema[0].dataType
return _parse_datatype_string(ddl)

@classmethod
def _data_type_build_formatted_string(
@@ -1578,6 +1569,12 @@ def fromInternal(self, obj: Dict) -> Optional["VariantVal"]:
return None
return VariantVal(obj["value"], obj["metadata"])

def toInternal(self, obj: Any) -> Any:
raise PySparkNotImplementedError(
error_class="NOT_IMPLEMENTED",
message_parameters={"feature": "VariantType.toInternal"},
)


class UserDefinedType(DataType):
"""User-defined type (UDT).
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, OuterScopes}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal}

private[catalyst] object ScalaSubtypeLock

@@ -322,6 +322,7 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => STRICT_TIMESTAMP_ENCODER
case t if isSubtype(t, localTypeOf[java.time.Instant]) => STRICT_INSTANT_ENCODER
case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) => LocalDateTimeEncoder
case t if isSubtype(t, localTypeOf[VariantVal]) => VariantEncoder
case t if isSubtype(t, localTypeOf[Row]) => UnboundRowEncoder

// UDT encoders
@@ -20,13 +20,14 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkException.internalError
import org.apache.spark.api.python.{PythonEvalType, PythonFunction}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
import org.apache.spark.sql.catalyst.trees.TreePattern.{PYTHON_UDF, TreePattern}
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.types._

/**
* Helper functions for [[PythonUDF]]
@@ -63,6 +64,23 @@ trait PythonFuncExpression extends NonSQLExpression with UserDefinedExpression {
override def toString: String = s"$name(${children.mkString(", ")})#${resultId.id}$typeSuffix"

override def nullable: Boolean = true

override def checkInputDataTypes(): TypeCheckResult = {
val check = super.checkInputDataTypes()
if (check.isFailure) {
check
} else {
val exprReturningVariant = children.collectFirst {
case e: Expression if VariantExpressionEvalUtils.typeContainsVariant(e.dataType) => e
}
exprReturningVariant match {
case Some(e) => TypeCheckResult.DataTypeMismatch(
errorSubClass = "UNSUPPORTED_UDF_INPUT_TYPE",
messageParameters = Map("dataType" -> s"${e.dataType.sql}"))
case None => TypeCheckResult.TypeCheckSuccess
}
}
}
}

/**
@@ -79,6 +97,10 @@ case class PythonUDF(
resultId: ExprId = NamedExpression.newExprId)
extends Expression with PythonFuncExpression with Unevaluable {

if (VariantExpressionEvalUtils.typeContainsVariant(dataType)) {
throw QueryCompilationErrors.unsupportedUDFOuptutType(this, dataType)
}

lazy val resultAttribute: Attribute = AttributeReference(toPrettySQL(this), dataType, nullable)(
exprId = resultId)

@@ -121,6 +143,10 @@ case class PythonUDAF(
resultId: ExprId = NamedExpression.newExprId)
extends UnevaluableAggregateFunc with PythonFuncExpression {

if (VariantExpressionEvalUtils.typeContainsVariant(dataType)) {
throw QueryCompilationErrors.unsupportedUDFOuptutType(this, dataType)
}

override def evalType: Int = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF

override def sql(isDistinct: Boolean): String = {
@@ -187,6 +213,13 @@ case class PythonUDTF(
pythonUDTFPartitionColumnIndexes: Option[PythonUDTFPartitionColumnIndexes] = None)
extends UnevaluableGenerator with PythonFuncExpression {

elementSchema.collectFirst {
case sf: StructField if VariantExpressionEvalUtils.typeContainsVariant(sf.dataType) => sf
} match {
case Some(sf) => throw QueryCompilationErrors.unsupportedUDFOuptutType(this, sf.dataType)
case None =>
}

override lazy val canonicalized: Expression = {
val canonicalizedChildren = children.map(_.canonicalized)
// `resultId` can be seen as cosmetic variation in PythonUDTF, as it doesn't affect the result.
@@ -75,6 +75,16 @@ object VariantExpressionEvalUtils {
new VariantVal(v.getValue, v.getMetadata)
}

/** Returns `true` if a data type is or has a child variant type. */
def typeContainsVariant(dt: DataType): Boolean = dt match {
case _: VariantType => true
case st: StructType => st.fields.exists(f => typeContainsVariant(f.dataType))
case at: ArrayType => typeContainsVariant(at.elementType)
// Variants cannot be map keys.
case mt: MapType => typeContainsVariant(mt.valueType)
case _ => false
}

private def buildVariant(builder: VariantBuilder, input: Any, dataType: DataType): Unit = {
if (input == null) {
builder.appendNull()
@@ -3825,6 +3825,12 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
"class" -> unsupported.getClass.toString))
}

def unsupportedUDFOuptutType(expr: Expression, dt: DataType): Throwable = {
new AnalysisException(
errorClass = "DATATYPE_MISMATCH.UNSUPPORTED_UDF_OUTPUT_TYPE",
messageParameters = Map("sqlExpr" -> toSQLExpr(expr), "dataType" -> dt.sql))
}

def funcBuildError(funcName: String, cause: Exception): Throwable = {
cause.getCause match {
case st: SparkThrowable with Throwable => st