Commit b3e5af6

[SPARK-13898][SQL] Merge DatasetHolder and DataFrameHolder
## What changes were proposed in this pull request?

This patch merges DatasetHolder and DataFrameHolder. This makes more sense because DataFrame and Dataset are now one class. In addition, it fixes some minor issues from pull request #11732.

## How was this patch tested?

Updated existing unit tests that exercise these implicits.

Author: Reynold Xin <rxin@databricks.com>

Closes #11737 from rxin/SPARK-13898.
1 parent 5e86e92 commit b3e5af6
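
For orientation, here is a minimal sketch (not part of the patch) of what the merge means for user code, assuming a Spark 2.0-era SQLContext named sqlContext with its implicits imported:

```scala
import sqlContext.implicits._

case class Person(name: String, age: Int)
val people = Seq(Person("alice", 30), Person("bob", 25))

// Both conversions now go through the single DatasetHolder produced by
// localSeqToDatasetHolder, rather than a separate DataFrameHolder for toDF.
val ds = people.toDS()               // Dataset[Person]
val df = people.toDF("name", "age")  // DataFrame with explicit column names
```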

File tree: 14 files changed, +36 additions, -137 deletions


project/MimaExcludes.scala

Lines changed: 7 additions & 6 deletions
@@ -300,12 +300,6 @@ object MimaExcludes {
         ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.DirectKafkaInputDStream.maxMessagesPerPartition")
       ) ++ Seq(
         // [SPARK-13244][SQL] Migrates DataFrame to Dataset
-        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameHolder.apply"),
-        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameHolder.toDF"),
-        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameHolder.copy"),
-        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameHolder.copy$default$1"),
-        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameHolder.df$1"),
-        ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameHolder.this"),
         ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.tables"),
         ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.sql"),
         ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.baseRelationToDataFrame"),
@@ -315,6 +309,13 @@ object MimaExcludes {
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame"),
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame$"),
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.LegacyFunctions"),
+        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameHolder"),
+        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameHolder$"),
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.localSeqToDataFrameHolder"),
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.stringRddToDataFrameHolder"),
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.rddToDataFrameHolder"),
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.longRddToDataFrameHolder"),
+        ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.intRddToDataFrameHolder"),
         ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"),

         ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"),

sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala

Lines changed: 0 additions & 37 deletions
This file was deleted.

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 5 additions & 7 deletions
@@ -2196,14 +2196,12 @@ class Dataset[T] private[sql](
   def write: DataFrameWriter = new DataFrameWriter(toDF())

   /**
-   * Returns the content of the [[Dataset]] as a [[Dataset]] of JSON strings.
-   *
-   * @group basic
-   * @since 1.6.0
+   * Returns the content of the [[Dataset]] as a Dataset of JSON strings.
+   * @since 2.0.0
    */
   def toJSON: Dataset[String] = {
     val rowSchema = this.schema
-    val rdd = queryExecution.toRdd.mapPartitions { iter =>
+    val rdd: RDD[String] = queryExecution.toRdd.mapPartitions { iter =>
       val writer = new CharArrayWriter()
       // create the Generator without separator inserted between 2 records
       val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
@@ -2225,8 +2223,8 @@ class Dataset[T] private[sql](
       }
     }
-    import sqlContext.implicits._
-    rdd.toDS
+    import sqlContext.implicits.newStringEncoder
+    sqlContext.createDataset(rdd)
   }

   /**
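
As a quick illustration of the reworked method, a hedged sketch (not from the patch; the JSON rendering shown is the usual output of Spark's JSON generator):

```scala
val people = Seq(("alice", 30), ("bob", 25)).toDF("name", "age")
val json: Dataset[String] = people.toJSON   // now built via sqlContext.createDataset(rdd)
json.collect().foreach(println)
// {"name":"alice","age":30}
// {"name":"bob","age":25}
```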

sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala

Lines changed: 7 additions & 1 deletion
@@ -18,7 +18,7 @@
 package org.apache.spark.sql

 /**
- * A container for a [[Dataset]], used for implicit conversions.
+ * A container for a [[Dataset]], used for implicit conversions in Scala.
  *
  * To use this, import implicit conversions in SQL:
  * {{{
@@ -32,4 +32,10 @@ case class DatasetHolder[T] private[sql](private val ds: Dataset[T]) {
   // This is declared with parentheses to prevent the Scala compiler from treating
   // `rdd.toDS("1")` as invoking this toDS and then apply on the returned Dataset.
   def toDS(): Dataset[T] = ds
+
+  // This is declared with parentheses to prevent the Scala compiler from treating
+  // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
+  def toDF(): DataFrame = ds.toDF()
+
+  def toDF(colNames: String*): DataFrame = ds.toDF(colNames : _*)
 }
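
With toDF moved onto DatasetHolder, a single implicit wrapper now serves both APIs. A usage sketch, assuming sqlContext.implicits._ is in scope:

```scala
val pairs = Seq((1, "a"), (2, "b"))

pairs.toDS()             // Dataset[(Int, String)] via DatasetHolder.toDS()
pairs.toDF()             // DataFrame with the default tuple column names _1, _2
pairs.toDF("id", "name") // DataFrame with explicit column names
```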

sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala

Lines changed: 0 additions & 71 deletions
@@ -147,75 +147,4 @@ abstract class SQLImplicits {
    */
   implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)

-  /**
-   * Creates a DataFrame from an RDD of Product (e.g. case classes, tuples).
-   * @since 1.3.0
-   */
-  implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = {
-    DataFrameHolder(_sqlContext.createDataFrame(rdd))
-  }
-
-  /**
-   * Creates a DataFrame from a local Seq of Product.
-   * @since 1.3.0
-   */
-  implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder =
-  {
-    DataFrameHolder(_sqlContext.createDataFrame(data))
-  }
-
-  // Do NOT add more implicit conversions for primitive types.
-  // They are likely to break source compatibility by making existing implicit conversions
-  // ambiguous. In particular, RDD[Double] is dangerous because of [[DoubleRDDFunctions]].
-
-  /**
-   * Creates a single column DataFrame from an RDD[Int].
-   * @since 1.3.0
-   */
-  implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = {
-    val dataType = IntegerType
-    val rows = data.mapPartitions { iter =>
-      val row = new SpecificMutableRow(dataType :: Nil)
-      iter.map { v =>
-        row.setInt(0, v)
-        row: InternalRow
-      }
-    }
-    DataFrameHolder(
-      _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
-  }
-
-  /**
-   * Creates a single column DataFrame from an RDD[Long].
-   * @since 1.3.0
-   */
-  implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = {
-    val dataType = LongType
-    val rows = data.mapPartitions { iter =>
-      val row = new SpecificMutableRow(dataType :: Nil)
-      iter.map { v =>
-        row.setLong(0, v)
-        row: InternalRow
-      }
-    }
-    DataFrameHolder(
-      _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
-  }
-
-  /**
-   * Creates a single column DataFrame from an RDD[String].
-   * @since 1.3.0
-   */
-  implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = {
-    val dataType = StringType
-    val rows = data.mapPartitions { iter =>
-      val row = new SpecificMutableRow(dataType :: Nil)
-      iter.map { v =>
-        row.update(0, UTF8String.fromString(v))
-        row: InternalRow
-      }
-    }
-    DataFrameHolder(
-      _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
-  }
 }
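
The removed conversions covered RDDs of Product and of Int, Long, and String. Equivalent results are now reachable through the encoder-based implicits plus the holder's new toDF; a sketch under that assumption (the exact implicit names, e.g. rddToDatasetHolder and newIntEncoder, are assumed rather than shown in this diff):

```scala
val ints = sparkContext.parallelize(1 to 3)

val ds = ints.toDS()   // Dataset[Int] via an Encoder[Int], no SpecificMutableRow plumbing needed
val df = ds.toDF()     // single-column DataFrame; the column is named "value" rather than "_1"
```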

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -612,15 +612,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     val longString = Array.fill(21)("1").mkString
     val df = sparkContext.parallelize(Seq("1", longString)).toDF()
     val expectedAnswerForFalse = """+---------------------+
-                                   ||_1                   |
+                                   ||value                |
                                    |+---------------------+
                                    ||1                    |
                                    ||111111111111111111111|
                                    |+---------------------+
                                    |""".stripMargin
     assert(df.showString(10, false) === expectedAnswerForFalse)
     val expectedAnswerForTrue = """+--------------------+
-                                  ||                  _1|
+                                  ||               value|
                                   |+--------------------+
                                   ||                   1|
                                   ||11111111111111111...|
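
The behavioral point these expectations pin down: a DataFrame built from an RDD[String] now takes its schema from the string encoder, so the lone column is called value instead of _1. A hedged sketch:

```scala
val df = sparkContext.parallelize(Seq("1", "111111111111111111111")).toDF()
df.printSchema()
// root
//  |-- value: string (nullable = true)
df.show(10, false)   // renders the table the test asserts on, with header "value"
```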

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 4 additions & 4 deletions
@@ -1621,15 +1621,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
   }

   test("SPARK-10215 Div of Decimal returns null") {
-    val d = Decimal(1.12321)
+    val d = Decimal(1.12321).toBigDecimal
     val df = Seq((d, 1)).toDF("a", "b")

     checkAnswer(
       df.selectExpr("b * a / b"),
-      Seq(Row(d.toBigDecimal)))
+      Seq(Row(d)))
     checkAnswer(
       df.selectExpr("b * a / b / b"),
-      Seq(Row(d.toBigDecimal)))
+      Seq(Row(d)))
     checkAnswer(
       df.selectExpr("b * a + b"),
       Seq(Row(BigDecimal(2.12321))))
@@ -1638,7 +1638,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       Seq(Row(BigDecimal(0.12321))))
     checkAnswer(
       df.selectExpr("b * a * b"),
-      Seq(Row(d.toBigDecimal)))
+      Seq(Row(d)))
   }

   test("precision smaller than scale") {

sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
 import org.apache.spark.sql.test.SharedSQLContext

 class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
-  import testImplicits.localSeqToDataFrameHolder
+  import testImplicits._

   test("shuffling UnsafeRows in exchange") {
     val input = (1 to 1000).map(Tuple1.apply)

sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala

Lines changed: 2 additions & 1 deletion
@@ -30,7 +30,8 @@ import org.apache.spark.sql.types._
  * sorted by a reference implementation ([[ReferenceSort]]).
  */
 class SortSuite extends SparkPlanTest with SharedSQLContext {
-  import testImplicits.localSeqToDataFrameHolder
+  import testImplicits.newProductEncoder
+  import testImplicits.localSeqToDatasetHolder

   test("basic sorting using ExternalSort") {
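
Why two imports are needed now: localSeqToDatasetHolder takes an implicit Encoder for the element type, whereas the removed localSeqToDataFrameHolder only required a TypeTag, so the product encoder must be imported alongside it. A sketch of the pattern inside a suite (testImplicits is the suite's SQLImplicits instance):

```scala
import testImplicits.newProductEncoder       // supplies Encoder[(Int, Int)] for the tuples below
import testImplicits.localSeqToDatasetHolder // Seq[T] => DatasetHolder[T], needs an Encoder[T]

val input = Seq((1, 10), (2, 20)).toDF("a", "b")   // compiles with just these two implicits
```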

sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
   test("decimal type") {
     // Casting is required here because ScalaReflection can't capture decimal precision information.
     val df = (1 to 10)
-      .map(i => Tuple1(Decimal(i, 15, 10)))
+      .map(i => Tuple1(Decimal(i, 15, 10).toJavaBigDecimal))
       .toDF("dec")
       .select($"dec" cast DecimalType(15, 10))
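
The in-code comment is terse; the underlying point, stated here as an assumption about ScalaReflection's defaults, is that a java.math.BigDecimal field maps to the catch-all DecimalType(38, 18), so the test casts back to the precision it actually wants:

```scala
// Hypothetical illustration, not from the patch (uses the suite's $ and DecimalType imports):
val df = Seq(Tuple1(new java.math.BigDecimal("1.0000000000"))).toDF("dec")
df.printSchema()
// root
//  |-- dec: decimal(38,18) (nullable = true)   <- system default inferred by ScalaReflection
val precise = df.select($"dec" cast DecimalType(15, 10))   // recover the intended precision
```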
