Skip to content

Commit bbd8f5b

Browse files
ueshinmarmbrus
authored andcommitted
[SPARK-4245][SQL] Fix containsNull of the result ArrayType of CreateArray expression.
The `containsNull` of the result `ArrayType` of `CreateArray` should be `true` only if the children is empty or there exists nullable child. Author: Takuya UESHIN <ueshin@happy-camper.st> Closes #3110 from ueshin/issues/SPARK-4245 and squashes the following commits: 6f64746 [Takuya UESHIN] Move equalsIgnoreNullability method into DataType. 5a90e02 [Takuya UESHIN] Refine InsertIntoHiveType and add some comments. cbecba8 [Takuya UESHIN] Fix a test title. 884ec37 [Takuya UESHIN] Merge branch 'master' into issues/SPARK-4245 3c5274b [Takuya UESHIN] Add tests to insert data of types ArrayType / MapType / StructType with nullability is false into Hive table. 41a94a9 [Takuya UESHIN] Replace InsertIntoTable with InsertIntoHiveTable if data types ignoring nullability are same. 43e6ef5 [Takuya UESHIN] Fix containsNull for empty array. 778e997 [Takuya UESHIN] Fix containsNull of the result ArrayType of CreateArray expression.
1 parent ade72c4 commit bbd8f5b

File tree

5 files changed

+106
-2
lines changed

5 files changed

+106
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,9 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
115115

116116
override def dataType: DataType = {
117117
assert(resolved, s"Invalid dataType of mixed ArrayType ${childTypes.mkString(",")}")
118-
ArrayType(childTypes.headOption.getOrElse(NullType))
118+
ArrayType(
119+
childTypes.headOption.getOrElse(NullType),
120+
containsNull = children.exists(_.nullable))
119121
}
120122

121123
override def nullable: Boolean = false

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,27 @@ object DataType {
171171
case _ =>
172172
}
173173
}
174+
175+
/**
176+
* Compares two types, ignoring nullability of ArrayType, MapType, StructType.
177+
*/
178+
def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
179+
(left, right) match {
180+
case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
181+
equalsIgnoreNullability(leftElementType, rightElementType)
182+
case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) =>
183+
equalsIgnoreNullability(leftKeyType, rightKeyType) &&
184+
equalsIgnoreNullability(leftValueType, rightValueType)
185+
case (StructType(leftFields), StructType(rightFields)) =>
186+
leftFields.size == rightFields.size &&
187+
leftFields.zip(rightFields)
188+
.forall{
189+
case (left, right) =>
190+
left.name == right.name && equalsIgnoreNullability(left.dataType, right.dataType)
191+
}
192+
case (left, right) => left == right
193+
}
194+
}
174195
}
175196

176197
abstract class DataType {

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,12 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
286286

287287
if (childOutputDataTypes == tableOutputDataTypes) {
288288
p
289+
} else if (childOutputDataTypes.size == tableOutputDataTypes.size &&
290+
childOutputDataTypes.zip(tableOutputDataTypes)
291+
.forall { case (left, right) => DataType.equalsIgnoreNullability(left, right) }) {
292+
// If both types ignoring nullability of ArrayType, MapType, StructType are the same,
293+
// use InsertIntoHiveTable instead of InsertIntoTable.
294+
InsertIntoHiveTable(p.table, p.partition, p.child, p.overwrite)
289295
} else {
290296
// Only do the casting when child output data types differ from table output data types.
291297
val castedChildOutput = child.output.zip(table.output).map {
@@ -316,6 +322,27 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
316322
override def unregisterAllTables() = {}
317323
}
318324

325+
/**
326+
* A logical plan representing insertion into Hive table.
327+
* This plan ignores nullability of ArrayType, MapType, StructType unlike InsertIntoTable
328+
* because Hive table doesn't have nullability for ARRAY, MAP, STRUCT types.
329+
*/
330+
private[hive] case class InsertIntoHiveTable(
331+
table: LogicalPlan,
332+
partition: Map[String, Option[String]],
333+
child: LogicalPlan,
334+
overwrite: Boolean)
335+
extends LogicalPlan {
336+
337+
override def children = child :: Nil
338+
override def output = child.output
339+
340+
override lazy val resolved = childrenResolved && child.output.zip(table.output).forall {
341+
case (childAttr, tableAttr) =>
342+
DataType.equalsIgnoreNullability(childAttr.dataType, tableAttr.dataType)
343+
}
344+
}
345+
319346
/**
320347
* :: DeveloperApi ::
321348
* Provides conversions between Spark SQL data types and Hive Metastore types.

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,11 @@ private[hive] trait HiveStrategies {
161161
object DataSinks extends Strategy {
162162
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
163163
case logical.InsertIntoTable(table: MetastoreRelation, partition, child, overwrite) =>
164-
InsertIntoHiveTable(table, partition, planLater(child), overwrite)(hiveContext) :: Nil
164+
execution.InsertIntoHiveTable(
165+
table, partition, planLater(child), overwrite)(hiveContext) :: Nil
166+
case hive.InsertIntoHiveTable(table: MetastoreRelation, partition, child, overwrite) =>
167+
execution.InsertIntoHiveTable(
168+
table, partition, planLater(child), overwrite)(hiveContext) :: Nil
165169
case logical.CreateTableAsSelect(
166170
Some(database), tableName, child, allowExisting, Some(extra: ASTNode)) =>
167171
CreateTableAsSelect(

sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,4 +121,54 @@ class InsertIntoHiveTableSuite extends QueryTest {
121121
sql("DROP TABLE table_with_partition")
122122
sql("DROP TABLE tmp_table")
123123
}
124+
125+
test("Insert ArrayType.containsNull == false") {
126+
val schema = StructType(Seq(
127+
StructField("a", ArrayType(StringType, containsNull = false))))
128+
val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i"))))
129+
val schemaRDD = applySchema(rowRDD, schema)
130+
schemaRDD.registerTempTable("tableWithArrayValue")
131+
sql("CREATE TABLE hiveTableWithArrayValue(a Array <STRING>)")
132+
sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue")
133+
134+
checkAnswer(
135+
sql("SELECT * FROM hiveTableWithArrayValue"),
136+
rowRDD.collect().toSeq)
137+
138+
sql("DROP TABLE hiveTableWithArrayValue")
139+
}
140+
141+
test("Insert MapType.valueContainsNull == false") {
142+
val schema = StructType(Seq(
143+
StructField("m", MapType(StringType, StringType, valueContainsNull = false))))
144+
val rowRDD = TestHive.sparkContext.parallelize(
145+
(1 to 100).map(i => Row(Map(s"key$i" -> s"value$i"))))
146+
val schemaRDD = applySchema(rowRDD, schema)
147+
schemaRDD.registerTempTable("tableWithMapValue")
148+
sql("CREATE TABLE hiveTableWithMapValue(m Map <STRING, STRING>)")
149+
sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue")
150+
151+
checkAnswer(
152+
sql("SELECT * FROM hiveTableWithMapValue"),
153+
rowRDD.collect().toSeq)
154+
155+
sql("DROP TABLE hiveTableWithMapValue")
156+
}
157+
158+
test("Insert StructType.fields.exists(_.nullable == false)") {
159+
val schema = StructType(Seq(
160+
StructField("s", StructType(Seq(StructField("f", StringType, nullable = false))))))
161+
val rowRDD = TestHive.sparkContext.parallelize(
162+
(1 to 100).map(i => Row(Row(s"value$i"))))
163+
val schemaRDD = applySchema(rowRDD, schema)
164+
schemaRDD.registerTempTable("tableWithStructValue")
165+
sql("CREATE TABLE hiveTableWithStructValue(s Struct <f: STRING>)")
166+
sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue")
167+
168+
checkAnswer(
169+
sql("SELECT * FROM hiveTableWithStructValue"),
170+
rowRDD.collect().toSeq)
171+
172+
sql("DROP TABLE hiveTableWithStructValue")
173+
}
124174
}

0 commit comments

Comments
 (0)