Commit cb8c653
[SPARK-44799][CONNECT] Fix outer scopes resolution on the executor side
### What changes were proposed in this pull request?
When you define a class in the REPL (together with previously defined symbols), for example:
```scala
val filePath = "my_path"
case class MyTestClass(value: Int)
```
the class is actually declared inside a command class. In Ammonite the generated structure looks roughly like this:
```scala
// The first command contains `filePath`.
object cmd1 {
  val wrapper = new cmd1
  val instance = new wrapper.Helper
}
class cmd1 extends Serializable {
  class Helper extends Serializable {
    val filePath = "my_path"
  }
}

// The second command contains the `MyTestClass` definition.
object cmd2 {
  val wrapper = new cmd2
  val instance = new wrapper.Helper
}
class cmd2 extends Serializable {
  @_root_.scala.transient private val __amm_usedThings =
    _root_.ammonite.repl.ReplBridge.value.usedEarlierDefinitions.iterator.toSet
  private val `cmd1`: cmd1.instance.type =
    if (__amm_usedThings("""cmd1""")) cmd1 else null.asInstanceOf[cmd1.instance.type]
  class Helper extends Serializable {
    case class MyTestClass(value: Int)
  }
}
```
In order to create an instance of `MyTestClass` we need an instance of the enclosing `Helper`. When an instance of the class is created by Spark itself, we use `OuterScopes`, which, for Ammonite-generated classes, accesses the command object to fetch the helper instance. The problem is that this access triggers the creation of a command instance, and creating a command instance tries to access the REPL to figure out which of its dependencies are in use (a clever compiler trick). This fails because in Connect the REPL is not running on the driver or the executors.

This PR fixes the issue by explicitly passing a getter for the outer instance to the `ProductEncoder`. For Ammonite we actually ship the helper instance itself. This way the encoder always carries the information it needs to create the class.

### Why are the changes needed?
This fixes a bug that occurs when you use a REPL-defined class as the input of a UDF. For example, the following works now:
```scala
// We need a previous cell that exposes a symbol that could be captured in the class definition.
val filePath = "my_path"
case class MyTestClass(value: Int) {
  override def toString: String = value.toString
}
spark.range(10).select(col("id").cast("int").as("value")).as[MyTestClass].map(mtc => mtc.value).collect()
```

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added a test to `ReplE2ESuite` to illustrate the issue.

Closes apache#42489 from hvanhovell/SPARK-44799.

Authored-by: Herman van Hovell <herman@databricks.com>
Signed-off-by: Herman van Hovell <herman@databricks.com>
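For reference, the idea behind the fix can be shown with a minimal, self-contained sketch (illustrative names only, not Spark's actual API): the encoder captures a getter for the enclosing instance at the time it is built, so the deserializer can later construct the nested class without having to resolve any outer scope.

```scala
// `Outer` stands in for a REPL cell's Helper, `Inner` for a class defined in that
// cell, and `SimpleEncoder` for an encoder that carries a getter for the enclosing
// instance. None of these are Spark classes.
class Outer {
  case class Inner(value: Int)
}

case class SimpleEncoder(cls: Class[_], outerPointerGetter: Option[() => AnyRef])

object OuterGetterSketch extends App {
  val helper = new Outer

  // The getter is captured while the enclosing instance is still easy to reach.
  val encoder = SimpleEncoder(classOf[Outer#Inner], Some(() => helper))

  // Later, wherever the object is deserialized, the enclosing instance simply
  // becomes the leading constructor argument; no outer-scope lookup is needed.
  val outerArgs = encoder.outerPointerGetter.map(_()).toSeq
  val constructor = encoder.cls.getConstructors.head
  println(constructor.newInstance((outerArgs :+ Int.box(42)): _*)) // Inner(42)
}
```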
1 parent ed906e0

11 files changed, +58 −22 lines

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 2 additions & 1 deletion
@@ -883,7 +883,8 @@ class Dataset[T] private[sql] (
         ClassTag(SparkClassUtils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple2")),
         Seq(
           EncoderField(s"_1", this.agnosticEncoder, leftNullable, Metadata.empty),
-          EncoderField(s"_2", other.agnosticEncoder, rightNullable, Metadata.empty)))
+          EncoderField(s"_2", other.agnosticEncoder, rightNullable, Metadata.empty)),
+        None)
 
     sparkSession.newDataset(tupleEncoder) { builder =>
       val joinBuilder = builder.getJoinBuilder
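Why the new trailing `None`: Scala tuples are top-level classes, so constructing them never requires an enclosing instance and there is nothing for the getter to capture. A quick way to convince yourself in a Scala or Ammonite REPL (illustrative only, not part of this change):

```scala
// Top-level classes have no enclosing class, so no outer instance is needed.
classOf[(Long, String)].getEnclosingClass   // null

// A class nested inside another class does need one.
class Wrapper { class Nested }
classOf[Wrapper#Nested].getEnclosingClass   // class Wrapper
```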

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala

Lines changed: 2 additions & 2 deletions
@@ -60,14 +60,14 @@ private[sql] class SparkResult[T](
         RowEncoder
           .encoderFor(dataType.asInstanceOf[StructType])
           .asInstanceOf[AgnosticEncoder[E]]
-      case ProductEncoder(clsTag, fields) if ProductEncoder.isTuple(clsTag) =>
+      case ProductEncoder(clsTag, fields, outer) if ProductEncoder.isTuple(clsTag) =>
         // Recursively continue updating the tuple product encoder
         val schema = dataType.asInstanceOf[StructType]
         assert(fields.length <= schema.fields.length)
         val updatedFields = fields.zipWithIndex.map { case (f, id) =>
           f.copy(enc = createEncoder(f.enc, schema.fields(id).dataType))
         }
-        ProductEncoder(clsTag, updatedFields)
+        ProductEncoder(clsTag, updatedFields, outer)
       case _ =>
         enc
     }

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala

Lines changed: 3 additions & 3 deletions
@@ -34,7 +34,7 @@ import org.apache.arrow.vector.ipc.ArrowReader
 import org.apache.arrow.vector.util.Text
 
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, OuterScopes}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.connect.client.CloseableIterator
@@ -288,9 +288,9 @@ object ArrowDeserializers {
         throw unsupportedCollectionType(tag.runtimeClass)
       }
 
-      case (ProductEncoder(tag, fields), StructVectors(struct, vectors)) =>
+      case (ProductEncoder(tag, fields, outerPointerGetter), StructVectors(struct, vectors)) =>
+        val outer = outerPointerGetter.map(_()).toSeq
         // We should try to make this work with MethodHandles.
-        val outer = Option(OuterScopes.getOuterScope(tag.runtimeClass)).map(_()).toSeq
         val Some(constructor) =
           ScalaReflection.findConstructor(
             tag.runtimeClass,
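With the getter carried by the encoder, the executor no longer consults `OuterScopes` at all: when a getter is present, invoking it yields the enclosing instance, which becomes the leading constructor argument; when it is absent, nothing is prepended. The `map(_()).toSeq` idiom used above behaves like this (plain Scala, illustrative values):

```scala
// Top-level class: no getter, so no extra constructor argument.
val noOuter: Option[() => AnyRef] = None
assert(noOuter.map(_()).toSeq.isEmpty)

// REPL-defined class: the captured helper becomes the single leading argument.
val helper = new AnyRef
val withOuter: Option[() => AnyRef] = Some(() => helper)
assert(withOuter.map(_()).toSeq == Seq(helper))
```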

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala

Lines changed: 1 addition & 1 deletion
@@ -413,7 +413,7 @@ object ArrowSerializer {
               serializerFor(value, structVector.getChild(MapVector.VALUE_NAME))) :: Nil)
         new ArraySerializer(v, extractor, structSerializer)
 
-      case (ProductEncoder(tag, fields), StructVectors(struct, vectors)) =>
+      case (ProductEncoder(tag, fields, _), StructVectors(struct, vectors)) =>
         if (isSubClass(classOf[Product], tag)) {
           structSerializerFor(fields, struct, vectors) { (_, i) => p =>
             p.asInstanceOf[Product].productElement(i)

connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala

Lines changed: 13 additions & 0 deletions
@@ -283,6 +283,19 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
     assertContains("""String = "[MyTestClass(1), MyTestClass(3)]"""", output)
   }
 
+  test("REPL class in encoder") {
+    val input = """
+        |case class MyTestClass(value: Int)
+        |spark.range(3).
+        |  select(col("id").cast("int").as("value")).
+        |  as[MyTestClass].
+        |  map(mtc => mtc.value).
+        |  collect()
+      """.stripMargin
+    val output = runCommandsInShell(input)
+    assertContains("Array[Int] = Array(0, 1, 2)", output)
+  }
+
   test("REPL class in UDF") {
     val input = """
         |case class MyTestClass(value: Int)

sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 3 additions & 2 deletions
@@ -28,7 +28,7 @@ import org.apache.commons.lang3.reflect.ConstructorUtils
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, OuterScopes}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.errors.ExecutionErrors
 import org.apache.spark.sql.types._
@@ -394,7 +394,8 @@ object ScalaReflection extends ScalaReflection {
             isRowEncoderSupported)
           EncoderField(fieldName, encoder, encoder.nullable, Metadata.empty)
         }
-        ProductEncoder(ClassTag(getClassFromType(t)), params)
+        val cls = getClassFromType(t)
+        ProductEncoder(ClassTag(cls), params, Option(OuterScopes.getOuterScope(cls)))
       case _ =>
         throw ExecutionErrors.cannotFindEncoderForTypeError(tpe.toString)
     }

sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala

Lines changed: 3 additions & 2 deletions
@@ -113,7 +113,8 @@ object AgnosticEncoders {
   // This supports both Product and DefinedByConstructorParams
   case class ProductEncoder[K](
       override val clsTag: ClassTag[K],
-      override val fields: Seq[EncoderField]) extends StructEncoder[K]
+      override val fields: Seq[EncoderField],
+      outerPointerGetter: Option[() => AnyRef]) extends StructEncoder[K]
 
   object ProductEncoder {
     val cachedCls = new ConcurrentHashMap[Int, Class[_]]
@@ -123,7 +124,7 @@ object AgnosticEncoders {
     }
     val cls = cachedCls.computeIfAbsent(encoders.size,
       _ => SparkClassUtils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}"))
-    ProductEncoder[Any](ClassTag(cls), fields)
+    ProductEncoder[Any](ClassTag(cls), fields, None)
   }
 
   private[sql] def isTuple(tag: ClassTag[_]): Boolean = {
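Constructing a `ProductEncoder` by hand now requires the third argument. A small sketch based on the signatures visible in this diff (the `Point` class and the use of `BoxedIntEncoder` are illustrative assumptions, not code from this commit):

```scala
import scala.reflect.ClassTag

import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedIntEncoder, EncoderField, ProductEncoder}
import org.apache.spark.sql.types.Metadata

case class Point(x: Int, y: Int)

// A top-level case class needs no enclosing instance, hence outerPointerGetter = None.
val pointEncoder = ProductEncoder[Point](
  ClassTag(classOf[Point]),
  Seq(
    EncoderField("x", BoxedIntEncoder, false, Metadata.empty),
    EncoderField("y", BoxedIntEncoder, false, Metadata.empty)),
  None)
```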

sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala

Lines changed: 26 additions & 7 deletions
@@ -82,13 +82,32 @@ object OuterScopes {
     if (outer == null) {
       outerClassName match {
         case AmmoniteREPLClass(cellClassName) =>
-          () => {
-            val objClass = SparkClassUtils.classForName(cellClassName)
-            val objInstance = objClass.getField("MODULE$").get(null)
-            val obj = objClass.getMethod("instance").invoke(objInstance)
-            addOuterScope(obj)
-            obj
-          }
+          /* A short introduction to Ammonite class generation.
+           *
+           * There are three classes generated for each command:
+           * - The command. This contains all the dependencies needed to execute the command. It
+           *   also contains some logic to only initialize dependencies it needs, the others will
+           *   be null. This logic is powered by the compiler, and it will only work when there is
+           *   an Ammonite REPL bound through the ReplBridge; it will fail with a
+           *   NullPointerException when this is not the case.
+           * - The Helper. This contains the user code. This is an inner class of the command. If
+           *   it needs one of its dependencies it will pull them from the command. The helper
+           *   instance is needed when a class is defined in the user code. This where this code
+           *   comes in, it resolves the Helper instance.
+           * - The command companion object. This holds an instance of the Helper class and the
+           *   command. When you touch the command companion on a machine where the REPL is not
+           *   running (driver and executors for connect), and the command has dependencies, the
+           *   initialization of the command will fail because it cannot use the REPL to figure out
+           *   which dependencies to retain.
+           *
+           * To by-pass the problem with executor side helper resolution, we eagerly capture the
+           * helper instance here.
+           */
+          val objClass = SparkClassUtils.classForName(cellClassName)
+          val objInstance = objClass.getField("MODULE$").get(null)
+          val obj = objClass.getMethod("instance").invoke(objInstance)
+          addOuterScope(obj)
+          () => obj
         // If the outer class is generated by REPL, users don't need to register it as it has
         // only one instance and there is a way to retrieve it: get the `$read` object, call the
         // `INSTANCE()` method to get the single instance of class `$read`. Then call `$iw()`
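The behavioural change is about when the command companion is touched: previously the returned thunk resolved the Helper lazily, possibly on an executor where no REPL (and hence no `ReplBridge`) exists; now the Helper is resolved once, right here, and the thunk only hands out the captured instance. A minimal sketch of that distinction, using stand-in classes rather than the real Ammonite-generated ones:

```scala
object EagerCaptureSketch extends App {
  final class Helper               // stands in for the generated cmd Helper
  var replRunning = true           // true where the REPL runs, false on executors

  object cmd {                     // stands in for the command companion object
    def instance: Helper = {
      require(replRunning, "initializing the command needs the REPL")
      new Helper
    }
  }

  // Old behaviour: resolution is deferred to call time, so invoking the thunk
  // on an executor (no REPL) fails.
  val lazyGetter: () => AnyRef = () => cmd.instance

  // New behaviour: resolve the Helper once while the REPL is still reachable
  // and capture the instance itself.
  val helper = cmd.instance
  val eagerGetter: () => AnyRef = () => helper

  replRunning = false              // simulate moving to an executor
  assert(eagerGetter() eq helper)  // still works
  // lazyGetter() would now throw: "requirement failed: initializing the command needs the REPL"
}
```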

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala

Lines changed: 2 additions & 2 deletions
@@ -350,7 +350,7 @@ object DeserializerBuildHelper {
           createDeserializer(valueEncoder, _, newTypePath),
           tag.runtimeClass)
 
-      case ProductEncoder(tag, fields) =>
+      case ProductEncoder(tag, fields, outerPointerGetter) =>
         val cls = tag.runtimeClass
         val dt = ObjectType(cls)
         val isTuple = cls.getName.startsWith("scala.Tuple")
@@ -373,7 +373,7 @@ object DeserializerBuildHelper {
         exprs.If(
           IsNull(path),
           exprs.Literal.create(null, dt),
-          NewInstance(cls, arguments, dt, propagateNull = false))
+          NewInstance(cls, arguments, Nil, propagateNull = false, dt, outerPointerGetter))
 
       case AgnosticEncoders.RowEncoder(fields) =>
         val convertedFields = fields.zipWithIndex.map { case (f, i) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala

Lines changed: 1 addition & 1 deletion
@@ -347,7 +347,7 @@ object SerializerBuildHelper {
           validateAndSerializeElement(valueEncoder, valueContainsNull))
       )
 
-    case ProductEncoder(_, fields) =>
+    case ProductEncoder(_, fields, _) =>
       val serializedFields = fields.map { field =>
         // SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNul
         // is necessary here. Because for a nullable nested inputObject with struct data
