17
17
18
18
package org .apache .spark .serializer
19
19
20
- import java .io .{ NotSerializableException , ObjectOutput , ObjectStreamClass , ObjectStreamField }
20
+ import java .io ._
21
21
import java .lang .reflect .{Field , Method }
22
22
import java .security .AccessController
23
23
@@ -145,17 +145,25 @@ private[spark] object SerializationDebugger extends Logging {
145
145
// An object contains multiple slots in serialization.
146
146
// Get the slots and visit fields in all of them.
147
147
val (finalObj, desc) = findObjectAndDescriptor(o)
148
+
149
+ if (! finalObj.eq(o)) {
150
+ return visit(finalObj, s " writeReplace data (class: ${finalObj.getClass.getName}) " :: stack)
151
+ }
152
+
148
153
val slotDescs = desc.getSlotDescs
149
154
var i = 0
150
155
while (i < slotDescs.length) {
151
156
val slotDesc = slotDescs(i)
152
157
if (slotDesc.hasWriteObjectMethod) {
153
- // TODO: Handle classes that specify writeObject method.
158
+ val childStack = visitSerializableWithWriteObjectMethod(finalObj, slotDesc, stack)
159
+ if (childStack.nonEmpty) {
160
+ return childStack
161
+ }
154
162
} else {
155
163
val fields : Array [ObjectStreamField ] = slotDesc.getFields
156
164
val objFieldValues : Array [Object ] = new Array [Object ](slotDesc.getNumObjFields)
157
165
val numPrims = fields.length - objFieldValues.length
158
- desc .getObjFieldValues(finalObj, objFieldValues)
166
+ slotDesc .getObjFieldValues(finalObj, objFieldValues)
159
167
160
168
var j = 0
161
169
while (j < objFieldValues.length) {
@@ -169,12 +177,39 @@ private[spark] object SerializationDebugger extends Logging {
169
177
}
170
178
j += 1
171
179
}
172
-
173
180
}
174
181
i += 1
175
182
}
176
183
return List .empty
177
184
}
185
+
186
+ private def visitSerializableWithWriteObjectMethod (
187
+ o : Object , slotDesc : ObjectStreamClass , stack : List [String ]): List [String ] = {
188
+ println(" >>> processing serializable with writeObject" + o)
189
+ val innerObjectsCatcher = new ListObjectOutputStream
190
+ var notSerializableFound = false
191
+ try {
192
+ innerObjectsCatcher.writeObject(o)
193
+ } catch {
194
+ case io : IOException =>
195
+ notSerializableFound = true
196
+ }
197
+ if (notSerializableFound) {
198
+ val innerObjects = innerObjectsCatcher.outputArray
199
+ var k = 0
200
+ while (k < innerObjects.length) {
201
+ val elem = s " writeObject data (class: ${slotDesc.getName}) "
202
+ val childStack = visit(innerObjects(k), elem :: stack)
203
+ if (childStack.nonEmpty) {
204
+ return childStack
205
+ }
206
+ k += 1
207
+ }
208
+ } else {
209
+ visited ++= innerObjectsCatcher.outputArray
210
+ }
211
+ return List .empty
212
+ }
178
213
}
179
214
180
215
/**
@@ -220,6 +255,27 @@ private[spark] object SerializationDebugger extends Logging {
220
255
override def writeByte (i : Int ): Unit = {}
221
256
}
222
257
258
+ /** An output stream that emulates /dev/null */
259
+ private class NullOutputStream extends OutputStream {
260
+ override def write (b : Int ) { }
261
+ }
262
+
263
+ /**
264
+ * A dummy [[ObjectOutputStream ]] that saves the list of objects written to it and returns
265
+ * them through `outputArray`.
266
+ */
267
+ private class ListObjectOutputStream extends ObjectOutputStream (new NullOutputStream ) {
268
+ private val output = new mutable.ArrayBuffer [Any ]
269
+ this .enableReplaceObject(true )
270
+
271
+ def outputArray : Array [Any ] = output.toArray
272
+
273
+ override def replaceObject (obj : Object ): Object = {
274
+ output += obj
275
+ obj
276
+ }
277
+ }
278
+
223
279
/** An implicit class that allows us to call private methods of ObjectStreamClass. */
224
280
implicit class ObjectStreamClassMethods (val desc : ObjectStreamClass ) extends AnyVal {
225
281
def getSlotDescs : Array [ObjectStreamClass ] = {
0 commit comments