
[SPARK-13456][SQL] fix creating encoders for case classes defined in Spark shell #11410

Closed
wants to merge 7 commits
@@ -288,7 +288,7 @@ class ReplSuite extends SparkFunSuite {
|import org.apache.spark.sql.Encoder
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
|val simpleSum = new Aggregator[Int, Int, Int] with Serializable {
|val simpleSum = new Aggregator[Int, Int, Int] {
Contributor Author:
Aggregator already extends Serializable
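
As a standalone illustration of why the mixin is redundant (a minimal sketch with a hypothetical stand-in base class, not the Spark `Aggregator` itself): when the abstract base already extends `Serializable`, anonymous subclasses inherit it, so repeating `with Serializable` adds nothing.

```scala
// Minimal sketch with a hypothetical stand-in for Aggregator: the base class
// already extends Serializable, so subclasses need not repeat the mixin.
abstract class BaseAgg[-I, B, O] extends Serializable {
  def zero: B
  def reduce(b: B, a: I): B
  def merge(b1: B, b2: B): B
  def finish(b: B): O
}

object MixinDemo extends App {
  // No `with Serializable` needed here; it is inherited from BaseAgg.
  val simpleSum = new BaseAgg[Int, Int, Int] {
    def zero: Int = 0
    def reduce(b: Int, a: Int): Int = b + a
    def merge(b1: Int, b2: Int): Int = b1 + b2
    def finish(b: Int): Int = b
  }
  println(simpleSum.isInstanceOf[Serializable]) // true
}
```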

| def zero: Int = 0 // The initial value.
| def reduce(b: Int, a: Int) = b + a // Add an element to the running total
| def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
@@ -347,7 +347,7 @@ class ReplSuite extends SparkFunSuite {
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
|/** An `Aggregator` that adds up any numeric type returned by the given function. */
|class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
|class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
| val numeric = implicitly[Numeric[N]]
| override def zero: N = numeric.zero
| override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
@@ -249,10 +249,32 @@ class ReplSuite extends SparkFunSuite {
// We need to use local-cluster to test this case.
val output = runInterpreter("local-cluster[1,1,1024]",
"""
|val sqlContext = new org.apache.spark.sql.SQLContext(sc)
|import sqlContext.implicits._
|case class TestCaseClass(value: Int)
|sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect()
|
|// Test Dataset Serialization in the REPL
|Seq(TestCaseClass(1)).toDS().collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}

test("Datasets and encoders") {
val output = runInterpreter("local",
"""
|import org.apache.spark.sql.functions._
|import org.apache.spark.sql.Encoder
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
|val simpleSum = new Aggregator[Int, Int, Int] {
| def zero: Int = 0 // The initial value.
| def reduce(b: Int, a: Int) = b + a // Add an element to the running total
| def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
| def finish(b: Int) = b // Return the final result.
|}.toColumn
|
|val ds = Seq(1, 2, 3, 4).toDS()
|ds.select(simpleSum).collect
Contributor:
I think your solution works, but it might be good to add a test case where the scope that gets captured has a side effect (i.e., create a file, so that if you double-execute the outer scope it will fail the second time).

Contributor Author:
I tried to add such a test, but couldn't figure it out. The outer scope is the line wrapper class that is generated by the Scala REPL framework, and I'm not sure how to insert a side effect into it.

Contributor:
I was just thinking you could do something like create a file and then make sure that file handle gets used in the closure. If you got something wrong, the second time the file is created it will fail.
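
A rough sketch of this suggestion (hypothetical names, and not the exact test that ends up in this PR): make the captured scope's initialization create a file via `Files.createFile`, which throws `FileAlreadyExistsException` if the scope is ever initialized, and the file created, a second time.

```scala
import java.nio.file.{Files, Path, Paths}

// Sketch of the side-effect idea: initialization of the enclosing scope creates
// a marker file; Files.createFile throws FileAlreadyExistsException if the
// scope is (incorrectly) initialized a second time.
object SideEffectScopeSketch extends App {
  val marker: Path =
    Paths.get(System.getProperty("java.io.tmpdir"), s"outer-scope-${System.currentTimeMillis}")

  class CapturedScope {
    Files.createFile(marker) // side effect: fails loudly on a second initialization
    case class Record(value: Int)
  }

  val scope = new CapturedScope
  println(scope.Record(1)) // uses the inner case class defined in the scope
  Files.deleteIfExists(marker)
}
```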

""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
@@ -295,6 +317,31 @@ class ReplSuite extends SparkFunSuite {
}
}

test("Datasets agg type-inference") {
val output = runInterpreter("local",
"""
|import org.apache.spark.sql.functions._
|import org.apache.spark.sql.Encoder
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
|/** An `Aggregator` that adds up any numeric type returned by the given function. */
|class SumOf[I, N : Numeric](f: I => N) extends
| org.apache.spark.sql.expressions.Aggregator[I, N, N] {
Contributor Author:
The import doesn't work here, not sure why

| val numeric = implicitly[Numeric[N]]
| override def zero: N = numeric.zero
| override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
| override def merge(b1: N, b2: N): N = numeric.plus(b1, b2)
| override def finish(reduction: N): N = reduction
|}
|
|def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
|val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
|ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}

test("collecting objects of class defined in repl") {
val output = runInterpreter("local[2]",
"""
@@ -317,4 +364,21 @@
assertDoesNotContain("Exception", output)
assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output)
}

test("line wrapper only initialized once when used as encoder outer scope") {
val output = runInterpreter("local",
"""
|val fileName = "repl-test-" + System.currentTimeMillis
|val tmpDir = System.getProperty("java.io.tmpdir")
|val file = new java.io.File(tmpDir, fileName)
|def createFile(): Unit = file.createNewFile()
|
|createFile();case class TestCaseClass(value: Int)
|sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect()
|
|file.delete()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}
}
@@ -567,7 +567,7 @@ class Analyzer(
if n.outerPointer.isEmpty &&
n.cls.isMemberClass &&
!Modifier.isStatic(n.cls.getModifiers) =>
val outer = OuterScopes.outerScopes.get(n.cls.getDeclaringClass.getName)
val outer = OuterScopes.getOuterScope(n.cls)
if (outer == null) {
throw new AnalysisException(
s"Unable to generate an encoder for inner class `${n.cls.getName}` without " +
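
For background on why an outer instance is needed at all (a standalone sketch, not the analyzer's resolution code): a non-static member class's constructor takes the enclosing instance as a hidden first argument, so reflective construction, which is what deserialization does, must supply an outer object.

```scala
// Standalone sketch: constructing an inner (member) class reflectively requires
// the enclosing instance as the first constructor argument, which is exactly
// the "outer pointer" the encoder framework has to track down.
class Enclosing {
  case class Inner(value: Int)
}

object OuterPointerDemo extends App {
  val outer = new Enclosing
  val innerCls = classOf[Enclosing#Inner]
  println(innerCls.isMemberClass) // true: the same check the analyzer performs

  // Synthetic constructor shape: (Enclosing, Int) => Inner
  val ctor = innerCls.getDeclaredConstructor(classOf[Enclosing], classOf[Int])
  val instance = ctor.newInstance(outer, Integer.valueOf(42))
  println(instance) // Inner(42)
}
```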
@@ -21,14 +21,16 @@ import java.util.concurrent.ConcurrentMap

import com.google.common.collect.MapMaker

import org.apache.spark.util.Utils

object OuterScopes {
@transient
lazy val outerScopes: ConcurrentMap[String, AnyRef] =
new MapMaker().weakValues().makeMap()

/**
* Adds a new outer scope to this context that can be used when instantiating an `inner class`
* during deserialialization. Inner classes are created when a case class is defined in the
* during deserialization. Inner classes are created when a case class is defined in the
* Spark REPL and registering the outer scope that this class was defined in allows us to create
* new instances on the spark executors. In normal use, users should not need to call this
* function.
@@ -39,4 +41,47 @@ object OuterScopes {
def addOuterScope(outer: AnyRef): Unit = {
outerScopes.putIfAbsent(outer.getClass.getName, outer)
}

def getOuterScope(innerCls: Class[_]): AnyRef = {
assert(innerCls.isMemberClass)
val outerClassName = innerCls.getDeclaringClass.getName
val outer = outerScopes.get(outerClassName)
if (outer == null) {
outerClassName match {
// If the outer class is generated by the REPL, users don't need to register it, as it has
// only one instance and there is a way to retrieve it: get the `$read` object and call its
// `INSTANCE()` method to get the single instance of class `$read`, then call the `$iw()`
// method multiple times to get the single instance of the innermost `$iw` class.
case REPLClass(baseClassName) =>
val objClass = Utils.classForName(baseClassName + "$")
val objInstance = objClass.getField("MODULE$").get(null)
val baseInstance = objClass.getMethod("INSTANCE").invoke(objInstance)
val baseClass = Utils.classForName(baseClassName)

var getter = iwGetter(baseClass)
var obj = baseInstance
while (getter != null) {
obj = getter.invoke(obj)
getter = iwGetter(getter.getReturnType)
}

Contributor Author:
hi @retronym, now I loop until there is no `$iw` method, does this look good to you?

Yes, this looks cleaner to me.

outerScopes.putIfAbsent(outerClassName, obj)
obj
case _ => null
}
} else {
outer
}
}

private def iwGetter(cls: Class[_]) = {
try {
cls.getMethod("$iw")
} catch {
case _: NoSuchMethodException => null
}
}

// The format of a REPL-generated wrapper class name, e.g. `$line12.$read$$iw$$iw`
private[this] val REPLClass = """^(\$line(?:\d+)\.\$read)(?:\$\$iw)+$""".r
}