[SPARK-5573][SQL] Add explode to dataframes #4546

Closed · wants to merge 7 commits
@@ -73,6 +73,25 @@ abstract class Generator extends Expression {
}
}

/**
* A generator that produces its output using the provided lambda function.
*/
case class UserDefinedGenerator(
    schema: Seq[Attribute],
    function: Row => TraversableOnce[Row],
    children: Seq[Expression])
  extends Generator {

  override protected def makeOutput(): Seq[Attribute] = schema

  override def eval(input: Row): TraversableOnce[Row] = {
    // Project the incoming row down to the generator's input expressions
    // before handing it to the user-supplied function.
    val inputRow = new InterpretedProjection(children)
    function(inputRow(input))
  }

  override def toString = s"UserDefinedGenerator(${children.mkString(",")})"
}
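For context, a minimal sketch of how this generator could be instantiated by hand: one output attribute and a function that splits a projected string value into one row per token. The `inputExpr` below is a placeholder for the resolved expression of the source column — an assumption for illustration, not part of this patch:

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UserDefinedGenerator}
import org.apache.spark.sql.types.StringType

// One output column named "word".
val wordAttr = AttributeReference("word", StringType, nullable = true)()

// Split the single projected input value into one Row per token.
val splitWords: Row => TraversableOnce[Row] =
  row => row.getString(0).split(" ").toSeq.map(w => Row(w))

// `inputExpr` stands in for the resolved child expression (assumption).
val generator = UserDefinedGenerator(wordAttr :: Nil, splitWords, inputExpr :: Nil)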

/**
* Given an input array produces a sequence of rows for each value in the array.
*/
sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala (38 additions, 0 deletions)
@@ -19,6 +19,7 @@ package org.apache.spark.sql

import scala.collection.JavaConversions._
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal

import org.apache.spark.annotation.{DeveloperApi, Experimental}
@@ -441,6 +442,43 @@ trait DataFrame extends RDDApi[Row] with Serializable {
sample(withReplacement, fraction, Utils.random.nextLong)
}

/**
* (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more
* rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of
* the input row are implicitly joined with each row that is output by the function.
*
* The following example uses this function to count the number of books which contain
* a given word:
*
* {{{
* case class Book(title: String, words: String)
* val df: RDD[Book]
*
* case class Word(word: String)
* val allWords = df.explode('words) {
* case Row(words: String) => words.split(" ").map(Word(_))
* }
*
* val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title"))
* }}}
*/
def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame


/**
* (Scala-specific) Returns a new [[DataFrame]] where a single column has been expanded to zero
* or more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All
* columns of the input row are implicitly joined with each value that is output by the function.
*
* {{{
* df.explode("words", "word") { words: String => words.split(" ") }
* }}}
*/
def explode[A, B : TypeTag](
    inputColumn: String,
    outputColumn: String)(
    f: A => TraversableOnce[B]): DataFrame
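A side-by-side sketch of the two overloads on the same data may help; it assumes a DataFrame `df` with a string column "words" and the SQLContext implicits in scope (names here are illustrative):

case class Word(word: String)

// Column-based overload: the function receives a Row of the selected input
// columns and returns case-class instances whose fields become new columns.
val byColumns = df.explode('words) {
  case Row(ws: String) => ws.split(" ").map(Word(_))
}

// String-based overload: one named input column, one named output column,
// and a typed function over the column's value.
val byName = df.explode("words", "word") { ws: String => ws.split(" ").toSeq }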

/////////////////////////////////////////////////////////////////////////////

/**
sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala (28 additions, 2 deletions)
@@ -21,6 +21,7 @@ import java.io.CharArrayWriter

import scala.language.implicitConversions
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import scala.collection.JavaConversions._

import com.fasterxml.jackson.core.JsonFactory
@@ -29,7 +30,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
import org.apache.spark.sql.catalyst.{expressions, SqlParser, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
@@ -39,7 +40,6 @@ import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{NumericType, StructType}


/**
* Internal implementation of [[DataFrame]]. Users of the API should use [[DataFrame]] directly.
*/
@@ -282,6 +282,32 @@ private[sql] class DataFrameImpl protected[sql](
Sample(fraction, withReplacement, seed, logicalPlan)
}

  override def explode[A <: Product : TypeTag]
      (input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
    val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
    val attributes = schema.toAttributes
    val rowFunction =
      f.andThen(_.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row]))
    val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr))

    Generate(generator, join = true, outer = false, None, logicalPlan)
  }
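To make the reflection step concrete, a REPL-style sketch (assuming this era's catalyst API) of what `schemaFor` yields for a one-field case class:

import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.StructType

case class Word(word: String)
val schema = ScalaReflection.schemaFor[Word].dataType.asInstanceOf[StructType]
// schema == StructType(StructField("word", StringType, true) :: Nil)
// schema.toAttributes supplies the Generate node's output attributes, and
// convertToCatalyst turns each user-returned Word into a catalyst Row.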

  override def explode[A, B : TypeTag](
      inputColumn: String,
      outputColumn: String)(
      f: A => TraversableOnce[B]): DataFrame = {
    val dataType = ScalaReflection.schemaFor[B].dataType
    val attributes = AttributeReference(outputColumn, dataType)() :: Nil
    def rowFunction(row: Row) = {
      f(row(0).asInstanceOf[A]).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType)))
    }
    val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil)

    Generate(generator, join = true, outer = false, None, logicalPlan)
  }

Review comment (Contributor), on the `rowFunction` definition: if you define this as a val, I think it doesn't capture outer?
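Traced outside the planner, the wrapping done by `rowFunction` looks like this; a self-contained sketch (not this patch's code path), with `convertToCatalyst` omitted since it is the identity for plain strings:

import org.apache.spark.sql.Row

val f: String => Seq[String] = _.split(" ").toSeq

// Unwrap the single input value, apply f, and re-wrap each result in a Row.
def rowFunction(row: Row): Seq[Row] =
  f(row(0).asInstanceOf[String]).map(o => Row(o))

rowFunction(Row("a b"))  // => Seq(Row("a"), Row("b"))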

/////////////////////////////////////////////////////////////////////////////
// RDD API
/////////////////////////////////////////////////////////////////////////////
@@ -18,6 +18,7 @@
package org.apache.spark.sql

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
@@ -110,6 +111,14 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten

override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = err()

  override def explode[A <: Product : TypeTag]
      (input: Column*)(f: Row => TraversableOnce[A]): DataFrame = err()

  override def explode[A, B : TypeTag](
      inputColumn: String,
      outputColumn: String)(
      f: A => TraversableOnce[B]): DataFrame = err()

/////////////////////////////////////////////////////////////////////////////

override def head(n: Int): Array[Row] = err()
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala (25 additions, 0 deletions)
@@ -98,6 +98,31 @@ class DataFrameSuite extends QueryTest {
sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq)
}

test("simple explode") {
val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDataFrame("words")

checkAnswer(
df.explode("words", "word") { word: String => word.split(" ").toSeq }.select('word),
Row("a") :: Row("b") :: Row("c") :: Row("d") ::Row("e") :: Nil
)
}

test("explode") {
val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDataFrame("number", "letters")
val df2 =
df.explode('letters) {
case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq
}

checkAnswer(
df2
.select('_1 as 'letter, 'number)
.groupBy('letter)
.agg('letter, countDistinct('number)),
Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil
)
}

test("selectExpr") {
checkAnswer(
testData.selectExpr("abs(key)", "value"),