Skip to content

Commit 47ebea5

Browse files
marmbrus authored and pwendell committed
[SQL] SPARK-1364 Improve datatype and test coverage for ScalaReflection schema inference.
Author: Michael Armbrust <michael@databricks.com>
Closes #293 from marmbrus/reflectTypes and squashes the following commits:
f54e8e8 [Michael Armbrust] Improve datatype and test coverage for ScalaReflection schema inference.
1 parent 9c65fa7 commit 47ebea5

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,25 @@ object ScalaReflection {
4343
val params = t.member("<init>": TermName).asMethod.paramss
4444
StructType(
4545
params.head.map(p => StructField(p.name.toString, schemaFor(p.typeSignature), true)))
46+
// Need to decide if we actually need a special type here.
47+
case t if t <:< typeOf[Array[Byte]] => BinaryType
48+
case t if t <:< typeOf[Array[_]] =>
49+
sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
4650
case t if t <:< typeOf[Seq[_]] =>
4751
val TypeRef(_, _, Seq(elementType)) = t
4852
ArrayType(schemaFor(elementType))
53+
case t if t <:< typeOf[Map[_,_]] =>
54+
val TypeRef(_, _, Seq(keyType, valueType)) = t
55+
MapType(schemaFor(keyType), schemaFor(valueType))
4956
case t if t <:< typeOf[String] => StringType
5057
case t if t <:< definitions.IntTpe => IntegerType
5158
case t if t <:< definitions.LongTpe => LongType
59+
case t if t <:< definitions.FloatTpe => FloatType
5260
case t if t <:< definitions.DoubleTpe => DoubleType
5361
case t if t <:< definitions.ShortTpe => ShortType
5462
case t if t <:< definitions.ByteTpe => ByteType
63+
case t if t <:< definitions.BooleanTpe => BooleanType
64+
case t if t <:< typeOf[BigDecimal] => DecimalType
5565
}
5666

5767
implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) {
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql
19+
20+
import org.scalatest.FunSuite
21+
22+
import org.apache.spark.sql.test.TestSQLContext._
23+
24+
/**
 * Test fixture row with one field per Scala type exercised by the
 * "query case class RDD" test: a String, the numeric primitives
 * (Int, Long, Float, Double, Short, Byte), a Boolean, a BigDecimal and a Seq[Int].
 */
case class ReflectData(
25+
stringField: String,
26+
intField: Int,
27+
longField: Long,
28+
floatField: Float,
29+
doubleField: Double,
30+
shortField: Short,
31+
byteField: Byte,
32+
booleanField: Boolean,
33+
decimalField: BigDecimal,
34+
seqInt: Seq[Int])
35+
36+
/**
 * Test fixture row holding a single Array[Byte] column. It is kept apart from
 * ReflectData so that the array contents can be compared as a Seq in their own test.
 */
case class ReflectBinary(data: Array[Byte])
37+
38+
/**
 * Checks that RDDs of case classes can be registered as tables and read back
 * through SQL queries with their field values intact.
 */
class ScalaReflectionRelationSuite extends FunSuite {
39+
test("query case class RDD") {
40+
// A single row that sets every field of ReflectData.
val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
41+
BigDecimal(1), Seq(1,2,3))
42+
val rdd = sparkContext.parallelize(data :: Nil)
43+
rdd.registerAsTable("reflectData")
44+
45+
// The first returned row must equal the case class's field values, in declaration order.
assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq)
46+
}
47+
48+
// Equality is broken for Arrays, so we test that separately.
49+
test("query binary data") {
50+
val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil)
51+
rdd.registerAsTable("reflectBinary")
52+
53+
// Take the first column of the first row and treat it as the raw byte array.
val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]]
54+
// Compare as a Seq so the check is on contents rather than on the array instance.
assert(result.toSeq === Seq[Byte](1))
55+
}
56+
}

0 commit comments

Comments
 (0)