Commit 3d4ce67 (2 parents: 457d699 + c6c60e3)

Merge pull request #4 from gzm0/typedSql

Remove intermediate map for records. Allow serialization

2 files changed: 40 additions, 14 deletions

sql/core/src/main/scala/org/apache/spark/sql/TypedSql.scala

Lines changed: 36 additions & 6 deletions
@@ -8,6 +8,7 @@ import org.apache.spark.sql.catalyst.types._
 import scala.language.experimental.macros

 import records._
+import Macros.RecordMacros

 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
@@ -89,17 +90,46 @@ object SQLMacros {

     val analyzedPlan = analyzer(logicalPlan)

-    val fields = analyzedPlan.output.zipWithIndex.map {
-      case (attr, i) =>
-        q"""${attr.name} -> row.${newTermName("get" + primitiveForType(attr.dataType))}($i)"""
+    // TODO: This probably shouldn't be here but somewhere generic
+    // which defines the catalyst <-> Scala type mapping
+    def toScalaType(dt: DataType) = dt match {
+      case IntegerType => definitions.IntTpe
+      case LongType => definitions.LongTpe
+      case ShortType => definitions.ShortTpe
+      case ByteType => definitions.ByteTpe
+      case DoubleType => definitions.DoubleTpe
+      case FloatType => definitions.FloatTpe
+      case BooleanType => definitions.BooleanTpe
+      case StringType => definitions.StringClass.toType
     }

+    val schema = analyzedPlan.output.map(attr => (attr.name, toScalaType(attr.dataType)))
+    val dataImpl = {
+      // Generate a case for each field.
+      val cases = analyzedPlan.output.zipWithIndex.map {
+        case (attr, i) =>
+          cq"""${attr.name} => row.${newTermName("get" + primitiveForType(attr.dataType))}($i)"""
+      }
+
+      // Implement __data using these cases.
+      // TODO: Unfortunately, this still boxes. We cannot resolve this,
+      // since the R abstraction depends on the fully generic __data.
+      // The only way to change this is to create __dataLong, etc. on
+      // R itself.
+      q"""
+        val res = fieldName match {
+          case ..$cases
+          case _ => ???
+        }
+        res.asInstanceOf[T]
+      """
+    }
+
+    val record: c.Expr[Nothing] = new RecordMacros[c.type](c).record(schema)(tq"Serializable")()(dataImpl)
     val tree = q"""
-      import records.R
       ..${args.zipWithIndex.map{ case (r,i) => q"""$r.registerAsTable(${s"table$i"})""" }}
       val result = sql($query)
-      // TODO: Avoid double copy
-      result.map(row => R(..$fields))
+      result.map(row => $record)
     """

     println(tree)
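
Before this change, the macro built each row as R(..$fields) from an intermediate name -> value map, and the resulting records were not serializable. Now RecordMacros.record generates the record class directly, mixing in Serializable and supplying the __data implementation. As a rough, hand-written sketch of what the generated tree amounts to for a query returning (name: String, age: Int) — the record base type and the exact __data signature here are assumptions inferred from the diff, not the literal expansion:

// Illustrative sketch only: approximates what RecordMacros.record
// produces for schema List(("name", StringClass.toType), ("age", IntTpe)).
// The base trait records.R and the __data signature are assumed.
result.map { row =>
  new records.R with Serializable {
    // One generic accessor dispatching on the field name; the value is
    // boxed before the cast (the boxing TODO noted in the diff).
    def __data[T](fieldName: String): T = {
      val res = fieldName match {
        case "name" => row.getString(0) // StringType  -> getString
        case "age"  => row.getInt(1)    // IntegerType -> getInt
        case _      => ???
      }
      res.asInstanceOf[T]
    }
  }
}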

sql/core/src/test/scala/org/apache/spark/sql/TypedSqlSuite.scala

Lines changed: 4 additions & 8 deletions
@@ -43,13 +43,9 @@ class TypedSqlSuite extends FunSuite {

   ignore("nested results") { }

-  ignore("join query") {
-    val results = sql"""
-      SELECT a.name
-      FROM $people a
-      JOIN $people b ON a.age = b.age
-    """
-    // TODO: R is not serializable.
-    // assert(results.first().name == "Michael")
+  test("join query") {
+    val results = sql"""SELECT a.name FROM $people a JOIN $people b ON a.age = b.age"""
+
+    assert(results.first().name == "Michael")
   }
 }
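
With the record now serializable, the previously ignored join test can run end-to-end. A hypothetical follow-on usage in the same spirit (not part of this commit; names mirror the suite):

// Hypothetical usage: because the generated record mixes in
// Serializable, it can be shipped to executors in further RDD
// transformations, not just read back via first().
val results = sql"""SELECT a.name FROM $people a JOIN $people b ON a.age = b.age"""
val upper = results.map(r => r.name.toUpperCase)
assert(upper.first() == "MICHAEL")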
