Skip to content

[SPARK-13101][SQL][branch-1.6] nullability of array type element should not fail analysis of encoder #11042

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ object JavaTypeInference {
val setter = if (nullable) {
constructor
} else {
AssertNotNull(constructor, other.getName, fieldName, fieldType.toString)
AssertNotNull(constructor, Seq("currently no type path record in java"))
}
p.getWriteMethod.getName -> setter
}.toMap
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,8 @@ object ScalaReflection extends ScalaReflection {

case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t

// TODO: add runtime null check for primitive array
val primitiveMethod = elementType match {
case t if t <:< definitions.IntTpe => Some("toIntArray")
case t if t <:< definitions.LongTpe => Some("toLongArray")
Expand Down Expand Up @@ -276,22 +278,29 @@ object ScalaReflection extends ScalaReflection {

case t if t <:< localTypeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
val className = getClassNameFromType(elementType)
val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath
val arrayData =
Invoke(
MapObjects(
p => constructorFor(elementType, Some(p), newTypePath),
getPath,
schemaFor(elementType).dataType),
"array",
ObjectType(classOf[Array[Any]]))

val mapFunction: Expression => Expression = p => {
val converter = constructorFor(elementType, Some(p), newTypePath)
if (nullable) {
converter
} else {
AssertNotNull(converter, newTypePath)
}
}

val array = Invoke(
MapObjects(mapFunction, getPath, dataType),
"array",
ObjectType(classOf[Array[Any]]))

StaticInvoke(
scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
arrayData :: Nil)
array :: Nil)

case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
Expand Down Expand Up @@ -343,7 +352,7 @@ object ScalaReflection extends ScalaReflection {
newTypePath)

if (!nullable) {
AssertNotNull(constructor, t.toString, fieldName, fieldType.toString)
AssertNotNull(constructor, newTypePath)
} else {
constructor
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1223,7 +1223,7 @@ object ResolveUpCast extends Rule[LogicalPlan] {
fail(child, DateType, walkedTypePath)
case (StringType, to: NumericType) =>
fail(child, to, walkedTypePath)
case _ => Cast(child, dataType)
case _ => Cast(child, dataType.asNullable)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ object MapObjects {
* to handle collection elements.
* @param inputData An expression that when evaluted returns a collection object.
*/
case class MapObjects(
case class MapObjects private(
loopVar: LambdaVariable,
lambdaFunction: Expression,
inputData: Expression) extends Expression {
Expand Down Expand Up @@ -633,8 +633,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
* `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all
* non-null `s`, `s.i` can't be null.
*/
case class AssertNotNull(
child: Expression, parentType: String, fieldName: String, fieldType: String)
case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
extends UnaryExpression {

override def dataType: DataType = child.dataType
Expand All @@ -647,6 +646,14 @@ case class AssertNotNull(
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val childGen = child.gen(ctx)

// This is going to be a string literal in generated java code, so we should escape `"` by `\"`
// and wrap every line with `"` at left side and `\n"` at right side, and finally concat them by
// ` + `.
val typePathString = walkedTypePath
.map(s => s.replaceAll("\"", "\\\\\""))
.map(s => '"' + s + "\\n\"")
.mkString(" + ")

ev.isNull = "false"
ev.value = childGen.value

Expand All @@ -655,7 +662,8 @@ case class AssertNotNull(

if (${childGen.isNull}) {
throw new RuntimeException(
"Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " +
"Null value appeared in non-nullable field:\\n" +
$typePathString +
"If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
"please try to use scala.Option[_] or other nullable types " +
"(e.g. java.lang.Integer instead of int/scala.Int)."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

case class StringLongClass(a: String, b: Long)

Expand All @@ -32,94 +34,49 @@ case class StringIntClass(a: String, b: Int)
case class ComplexClass(a: Long, b: StringLongClass)

class EncoderResolutionSuite extends PlanTest {
private val str = UTF8String.fromString("hello")

test("real type doesn't match encoder schema but they are compatible: product") {
val encoder = ExpressionEncoder[StringLongClass]
val cls = classOf[StringLongClass]


{
val attrs = Seq('a.string, 'b.int)
val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
val expected: Expression = NewInstance(
cls,
Seq(
toExternalString('a.string),
AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
),
ObjectType(cls),
propagateNull = false)
compareExpressions(fromRowExpr, expected)
}
// int type can be up cast to long type
val attrs1 = Seq('a.string, 'b.int)
encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1))

{
val attrs = Seq('a.int, 'b.long)
val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression
val expected = NewInstance(
cls,
Seq(
toExternalString('a.int.cast(StringType)),
AssertNotNull('b.long, cls.getName, "b", "Long")
),
ObjectType(cls),
propagateNull = false)
compareExpressions(fromRowExpr, expected)
}
// int type can be up cast to string type
val attrs2 = Seq('a.int, 'b.long)
encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L))
}

test("real type doesn't match encoder schema but they are compatible: nested product") {
val encoder = ExpressionEncoder[ComplexClass]
val innerCls = classOf[StringLongClass]
val cls = classOf[ComplexClass]

val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
val expected: Expression = NewInstance(
cls,
Seq(
AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"),
If(
'b.struct('a.int, 'b.long).isNull,
Literal.create(null, ObjectType(innerCls)),
NewInstance(
innerCls,
Seq(
toExternalString(
GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)),
AssertNotNull(
GetStructField('b.struct('a.int, 'b.long), 1, Some("b")),
innerCls.getName, "b", "Long")),
ObjectType(innerCls),
propagateNull = false)
)),
ObjectType(cls),
propagateNull = false)
compareExpressions(fromRowExpr, expected)
encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
}

test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
val encoder = ExpressionEncoder.tuple(
ExpressionEncoder[StringLongClass],
ExpressionEncoder[Long])
val cls = classOf[StringLongClass]

val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression
val expected: Expression = NewInstance(
classOf[Tuple2[_, _]],
Seq(
NewInstance(
cls,
Seq(
toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))),
AssertNotNull(
GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType),
cls.getName, "b", "Long")),
ObjectType(cls),
propagateNull = false),
'b.int.cast(LongType)),
ObjectType(classOf[Tuple2[_, _]]),
propagateNull = false)
compareExpressions(fromRowExpr, expected)
encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
}

test("nullability of array type element should not fail analysis") {
val encoder = ExpressionEncoder[Seq[Int]]
val attrs = 'a.array(IntegerType) :: Nil

// It should pass analysis
val bound = encoder.resolve(attrs, null).bind(attrs)

// If no null values appear, it should works fine
bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))

// If there is null value, it should throw runtime exception
val e = intercept[RuntimeException] {
bound.fromRow(InternalRow(new GenericArrayData(Array(1, null))))
}
assert(e.getMessage.contains("Null value appeared in non-nullable field"))
}

test("the real number of fields doesn't match encoder schema: tuple encoder") {
Expand Down Expand Up @@ -166,10 +123,6 @@ class EncoderResolutionSuite extends PlanTest {
}
}

private def toExternalString(e: Expression): Expression = {
Invoke(e, "toString", ObjectType(classOf[String]), Nil)
}

test("throw exception if real type is not compatible with encoder schema") {
val msg1 = intercept[AnalysisException] {
ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -848,9 +848,7 @@ public void testRuntimeNullabilityCheck() {
}

nullabilityCheck.expect(RuntimeException.class);
nullabilityCheck.expectMessage(
"Null value appeared in non-nullable field " +
"test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int.");
nullabilityCheck.expectMessage("Null value appeared in non-nullable field");

{
Row row = new GenericRow(new Object[] {
Expand Down
13 changes: 5 additions & 8 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
1, 1, 1)
}


test("SPARK-12404: Datatype Helper Serializablity") {
val ds = sparkContext.parallelize((
new Timestamp(0),
new Date(0),
java.math.BigDecimal.valueOf(1),
scala.math.BigDecimal(1)) :: Nil).toDS()
new Timestamp(0),
new Date(0),
java.math.BigDecimal.valueOf(1),
scala.math.BigDecimal(1)) :: Nil).toDS()

ds.collect()
}
Expand Down Expand Up @@ -542,9 +541,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
buildDataset(Row(Row("hello", null))).collect()
}.getMessage

assert(message.contains(
"Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int."
))
assert(message.contains("Null value appeared in non-nullable field"))
}

test("SPARK-12478: top level null field") {
Expand Down