Skip to content

Commit ee04a8b

Browse files
committed
[SPARK-5573][SQL] Add explode to dataframes
Author: Michael Armbrust <michael@databricks.com> Closes #4546 from marmbrus/explode and squashes the following commits: eefd33a [Michael Armbrust] whitespace a8d496c [Michael Armbrust] Merge remote-tracking branch 'apache/master' into explode 4af740e [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explode dc86a5c [Michael Armbrust] simple version d633d01 [Michael Armbrust] add scala specific 950707a [Michael Armbrust] fix comments ba8854c [Michael Armbrust] [SPARK-5573][SQL] Add explode to dataframes
1 parent c352ffb commit ee04a8b

File tree

5 files changed

+119
-2
lines changed

5 files changed

+119
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,25 @@ abstract class Generator extends Expression {
7373
}
7474
}
7575

76+
/**
 * A generator that produces its output using the provided lambda function.
 *
 * For every input row, the `children` expressions are projected into a single row through an
 * `InterpretedProjection`, and that row is handed to `function`.
 *
 * @param schema   attributes describing the rows emitted by `function`; returned unchanged by
 *                 `makeOutput()`
 * @param function user-supplied function invoked once per input row with the projected row
 * @param children expressions used to build the row that is passed to `function`
 */
case class UserDefinedGenerator(
    schema: Seq[Attribute],
    function: Row => TraversableOnce[Row],
    children: Seq[Expression])
  extends Generator {

  // Built once per generator instance instead of once per evaluated row. It is @transient and
  // lazy so the projection is not carried along when the expression is serialized and is simply
  // rebuilt from `children` on first use.
  @transient private[this] lazy val inputProjection = new InterpretedProjection(children)

  override protected def makeOutput(): Seq[Attribute] = schema

  override def eval(input: Row): TraversableOnce[Row] = function(inputProjection(input))

  override def toString: String = s"UserDefinedGenerator(${children.mkString(",")})"
}
94+
7695
/**
7796
* Given an input array produces a sequence of rows for each value in the array.
7897
*/

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql
1919

2020
import scala.collection.JavaConversions._
2121
import scala.reflect.ClassTag
22+
import scala.reflect.runtime.universe.TypeTag
2223
import scala.util.control.NonFatal
2324

2425
import org.apache.spark.annotation.{DeveloperApi, Experimental}
@@ -441,6 +442,43 @@ trait DataFrame extends RDDApi[Row] with Serializable {
441442
sample(withReplacement, fraction, Utils.random.nextLong)
442443
}
443444

445+
/**
 * (Scala-specific) Expands every row of this [[DataFrame]] into zero or more rows using the
 * supplied function, similar to a `LATERAL VIEW` in HiveQL. Each row produced by the function
 * is implicitly joined with the columns of the input row it was generated from.
 *
 * For example, to count the number of books which contain a given word:
 *
 * {{{
 *   case class Book(title: String, words: String)
 *   val df: DataFrame  // e.g. created from an RDD[Book]
 *
 *   case class Word(word: String)
 *   val allWords = df.explode('words) {
 *     case Row(words: String) => words.split(" ").map(Word(_))
 *   }
 *
 *   val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title"))
 * }}}
 *
 * @tparam A type of the objects returned by `f`
 * @param input columns whose values are handed to `f` as a [[Row]]
 * @param f     function turning one input [[Row]] into zero or more output objects
 */
def explode[A <: Product : TypeTag](
    input: Column*)(
    f: Row => TraversableOnce[A]): DataFrame
466+
467+
468+
/**
 * (Scala-specific) Returns a new [[DataFrame]] where a single column has been expanded to zero
 * or more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All
 * columns of the input row are implicitly joined with each value that is output by the function.
 *
 * {{{
 *   df.explode("words", "word") { words: String => words.split(" ") }
 * }}}
 *
 * @param inputColumn  name of the column whose value is passed to `f`
 * @param outputColumn name of the column that holds the values produced by `f`
 * @param f            function mapping one input value to zero or more output values
 */
def explode[A, B : TypeTag](
    inputColumn: String,
    outputColumn: String)(
    f: A => TraversableOnce[B]): DataFrame
481+
444482
/////////////////////////////////////////////////////////////////////////////
445483

446484
/**

sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import java.io.CharArrayWriter
2121

2222
import scala.language.implicitConversions
2323
import scala.reflect.ClassTag
24+
import scala.reflect.runtime.universe.TypeTag
2425
import scala.collection.JavaConversions._
2526

2627
import com.fasterxml.jackson.core.JsonFactory
@@ -29,7 +30,7 @@ import org.apache.spark.api.java.JavaRDD
2930
import org.apache.spark.api.python.SerDeUtil
3031
import org.apache.spark.rdd.RDD
3132
import org.apache.spark.storage.StorageLevel
32-
import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
33+
import org.apache.spark.sql.catalyst.{expressions, SqlParser, ScalaReflection}
3334
import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
3435
import org.apache.spark.sql.catalyst.expressions._
3536
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
@@ -39,7 +40,6 @@ import org.apache.spark.sql.json.JsonRDD
3940
import org.apache.spark.sql.sources._
4041
import org.apache.spark.sql.types.{NumericType, StructType}
4142

42-
4343
/**
4444
* Internal implementation of [[DataFrame]]. Users of the API should use [[DataFrame]] directly.
4545
*/
@@ -282,6 +282,32 @@ private[sql] class DataFrameImpl protected[sql](
282282
Sample(fraction, withReplacement, seed, logicalPlan)
283283
}
284284

285+
/**
 * Row-based `explode`: derives the output schema from `A` by reflection, wraps `f` so that
 * every object it returns is converted to a catalyst row, and plans a joined, non-outer
 * `Generate` over the current logical plan.
 */
override def explode[A <: Product : TypeTag]
    (input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
  val outputSchema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
  val outputAttributes = outputSchema.toAttributes

  // Apply the user function, then convert each produced object to a Row matching outputSchema.
  val toCatalystRows: Row => TraversableOnce[Row] = { inputRow =>
    f(inputRow).map { produced =>
      ScalaReflection.convertToCatalyst(produced, outputSchema).asInstanceOf[Row]
    }
  }

  val inputExpressions = input.map(_.expr)
  val generator = UserDefinedGenerator(outputAttributes, toCatalystRows, inputExpressions)

  Generate(generator, join = true, outer = false, None, logicalPlan)
}
295+
296+
/**
 * Single-column `explode`: the value of `inputColumn` is cast to `A` and passed to `f`; every
 * value `f` returns is converted to its catalyst form and emitted as a one-column row named
 * `outputColumn`. The result is planned as a joined, non-outer `Generate`.
 */
override def explode[A, B : TypeTag](
    inputColumn: String,
    outputColumn: String)(
    f: A => TraversableOnce[B]): DataFrame = {
  val elementType = ScalaReflection.schemaFor[B].dataType
  val outputAttributes = Seq(AttributeReference(outputColumn, elementType)())

  // The generator only projects `inputColumn`, so the argument is always at ordinal 0.
  val toCatalystRows: Row => TraversableOnce[Row] = { inputRow =>
    val argument = inputRow(0).asInstanceOf[A]
    f(argument).map(value => Row(ScalaReflection.convertToCatalyst(value, elementType)))
  }

  val inputExpressions = Seq(apply(inputColumn).expr)
  val generator = UserDefinedGenerator(outputAttributes, toCatalystRows, inputExpressions)

  Generate(generator, join = true, outer = false, None, logicalPlan)
}
310+
285311
/////////////////////////////////////////////////////////////////////////////
286312
// RDD API
287313
/////////////////////////////////////////////////////////////////////////////

sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql
1919

2020
import scala.reflect.ClassTag
21+
import scala.reflect.runtime.universe.TypeTag
2122

2223
import org.apache.spark.api.java.JavaRDD
2324
import org.apache.spark.rdd.RDD
@@ -110,6 +111,14 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
110111

111112
override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = err()
112113

114+
/** Row-based `explode`; unconditionally delegates to `err()`, ignoring its arguments. */
override def explode[A <: Product : TypeTag](
    input: Column*)(
    f: Row => TraversableOnce[A]): DataFrame = {
  err()
}
116+
117+
/** Single-column `explode`; unconditionally delegates to `err()`, ignoring its arguments. */
override def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(
    f: A => TraversableOnce[B]): DataFrame = {
  err()
}
121+
113122
/////////////////////////////////////////////////////////////////////////////
114123

115124
override def head(n: Int): Array[Row] = err()

sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,31 @@ class DataFrameSuite extends QueryTest {
9898
sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq)
9999
}
100100

101+
// Single-column explode: each space-separated token of "words" becomes its own "word" row.
test("simple explode") {
  val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDataFrame("words")
  val exploded = df.explode("words", "word") { words: String => words.split(" ").toSeq }
  val expected = Seq("a", "b", "c", "d", "e").map(Row(_))

  checkAnswer(exploded.select('word), expected)
}
109+
110+
// Row-based explode: split "letters" into one Tuple1 per token, then count the distinct
// "number" values each letter was generated from.
test("explode") {
  val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDataFrame("number", "letters")
  val exploded = df.explode('letters) {
    case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq
  }
  val numbersPerLetter = exploded
    .select('_1 as 'letter, 'number)
    .groupBy('letter)
    .agg('letter, countDistinct('number))
  val expected = Seq(Row("a", 3), Row("b", 2), Row("c", 1))

  checkAnswer(numbersPerLetter, expected)
}
125+
101126
test("selectExpr") {
102127
checkAnswer(
103128
testData.selectExpr("abs(key)", "value"),

0 commit comments

Comments
 (0)