
Commit 409d226

Simplify JSON de/serialization for BlockId
1 parent ada310a commit 409d226
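
The change drops the hand-written per-type JSON encoders/decoders and instead round-trips a BlockId through its name string: serialization writes id.toString, and deserialization rebuilds the concrete subclass with BlockId(name). A minimal sketch of the idea (illustrative names and values; it assumes org.apache.spark.storage.BlockId's string-parsing apply, which this commit relies on):

    import org.apache.spark.storage.{BlockId, RDDBlockId}

    object BlockIdRoundTripSketch {
      def main(args: Array[String]): Unit = {
        // A BlockId's name already encodes its type and fields,
        // e.g. RDDBlockId(1, 2).toString is "rdd_1_2".
        val original: BlockId = RDDBlockId(1, 2)
        val serialized: String = original.toString  // what JsonProtocol now emits for "Block ID"

        // BlockId(name) pattern-matches the name and reconstructs the
        // concrete subclass, so no per-type JSON fields are needed on read.
        val restored: BlockId = BlockId(serialized)
        assert(restored == original)
      }
    }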


2 files changed: 72 additions, 146 deletions


core/src/main/scala/org/apache/spark/util/JsonProtocol.scala

Lines changed: 2 additions & 75 deletions
@@ -195,7 +195,7 @@ private[spark] object JsonProtocol {
       taskMetrics.shuffleWriteMetrics.map(shuffleWriteMetricsToJson).getOrElse(JNothing)
     val updatedBlocks = taskMetrics.updatedBlocks.map { blocks =>
       JArray(blocks.toList.map { case (id, status) =>
-        ("Block ID" -> blockIdToJson(id)) ~
+        ("Block ID" -> id.toString) ~
         ("Status" -> blockStatusToJson(status))
       })
     }.getOrElse(JNothing)
@@ -284,35 +284,6 @@ private[spark] object JsonProtocol {
     ("Replication" -> storageLevel.replication)
   }
 
-  def blockIdToJson(blockId: BlockId): JValue = {
-    val blockType = Utils.getFormattedClassName(blockId)
-    val json: JObject = blockId match {
-      case rddBlockId: RDDBlockId =>
-        ("RDD ID" -> rddBlockId.rddId) ~
-        ("Split Index" -> rddBlockId.splitIndex)
-      case shuffleBlockId: ShuffleBlockId =>
-        ("Shuffle ID" -> shuffleBlockId.shuffleId) ~
-        ("Map ID" -> shuffleBlockId.mapId) ~
-        ("Reduce ID" -> shuffleBlockId.reduceId)
-      case broadcastBlockId: BroadcastBlockId =>
-        "Broadcast ID" -> broadcastBlockId.broadcastId
-      case broadcastHelperBlockId: BroadcastHelperBlockId =>
-        ("Broadcast Block ID" -> blockIdToJson(broadcastHelperBlockId.broadcastId)) ~
-        ("Helper Type" -> broadcastHelperBlockId.hType)
-      case taskResultBlockId: TaskResultBlockId =>
-        "Task ID" -> taskResultBlockId.taskId
-      case streamBlockId: StreamBlockId =>
-        ("Stream ID" -> streamBlockId.streamId) ~
-        ("Unique ID" -> streamBlockId.uniqueId)
-      case tempBlockId: TempBlockId =>
-        val uuid = UUIDToJson(tempBlockId.id)
-        "Temp ID" -> uuid
-      case testBlockId: TestBlockId =>
-        "Test ID" -> testBlockId.id
-    }
-    ("Type" -> blockType) ~ json
-  }
-
   def blockStatusToJson(blockStatus: BlockStatus): JValue = {
     val storageLevel = storageLevelToJson(blockStatus.storageLevel)
     ("Storage Level" -> storageLevel) ~
@@ -513,7 +484,7 @@ private[spark] object JsonProtocol {
       Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson)
     metrics.updatedBlocks = Utils.jsonOption(json \ "Updated Blocks").map { value =>
       value.extract[List[JValue]].map { block =>
-        val id = blockIdFromJson(block \ "Block ID")
+        val id = BlockId((block \ "Block ID").extract[String])
         val status = blockStatusFromJson(block \ "Status")
         (id, status)
       }
@@ -616,50 +587,6 @@ private[spark] object JsonProtocol {
     StorageLevel(useDisk, useMemory, deserialized, replication)
   }
 
-  def blockIdFromJson(json: JValue): BlockId = {
-    val rddBlockId = Utils.getFormattedClassName(RDDBlockId)
-    val shuffleBlockId = Utils.getFormattedClassName(ShuffleBlockId)
-    val broadcastBlockId = Utils.getFormattedClassName(BroadcastBlockId)
-    val broadcastHelperBlockId = Utils.getFormattedClassName(BroadcastHelperBlockId)
-    val taskResultBlockId = Utils.getFormattedClassName(TaskResultBlockId)
-    val streamBlockId = Utils.getFormattedClassName(StreamBlockId)
-    val tempBlockId = Utils.getFormattedClassName(TempBlockId)
-    val testBlockId = Utils.getFormattedClassName(TestBlockId)
-
-    (json \ "Type").extract[String] match {
-      case `rddBlockId` =>
-        val rddId = (json \ "RDD ID").extract[Int]
-        val splitIndex = (json \ "Split Index").extract[Int]
-        new RDDBlockId(rddId, splitIndex)
-      case `shuffleBlockId` =>
-        val shuffleId = (json \ "Shuffle ID").extract[Int]
-        val mapId = (json \ "Map ID").extract[Int]
-        val reduceId = (json \ "Reduce ID").extract[Int]
-        new ShuffleBlockId(shuffleId, mapId, reduceId)
-      case `broadcastBlockId` =>
-        val broadcastId = (json \ "Broadcast ID").extract[Long]
-        new BroadcastBlockId(broadcastId)
-      case `broadcastHelperBlockId` =>
-        val broadcastBlockId =
-          blockIdFromJson(json \ "Broadcast Block ID").asInstanceOf[BroadcastBlockId]
-        val hType = (json \ "Helper Type").extract[String]
-        new BroadcastHelperBlockId(broadcastBlockId, hType)
-      case `taskResultBlockId` =>
-        val taskId = (json \ "Task ID").extract[Long]
-        new TaskResultBlockId(taskId)
-      case `streamBlockId` =>
-        val streamId = (json \ "Stream ID").extract[Int]
-        val uniqueId = (json \ "Unique ID").extract[Long]
-        new StreamBlockId(streamId, uniqueId)
-      case `tempBlockId` =>
-        val tempId = UUIDFromJson(json \ "Temp ID")
-        new TempBlockId(tempId)
-      case `testBlockId` =>
-        val testId = (json \ "Test ID").extract[String]
-        new TestBlockId(testId)
-    }
-  }
-
   def blockStatusFromJson(json: JValue): BlockStatus = {
     val storageLevel = storageLevelFromJson(json \ "Storage Level")
     val memorySize = (json \ "Memory Size").extract[Long]
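
On the wire, the "Block ID" field thus changes from a type-tagged object to a plain name string. A rough illustration using the json4s DSL that JsonProtocol uses (the rdd_1_2 value is illustrative and assumes RDDBlockId's name format):

    import org.apache.spark.storage.RDDBlockId
    import org.json4s.JsonDSL._
    import org.json4s.jackson.JsonMethods.{compact, render}

    object BlockIdJsonShapeSketch {
      def main(args: Array[String]): Unit = {
        val id = RDDBlockId(1, 2)
        // New form: the block's name string, e.g. {"Block ID":"rdd_1_2"}
        println(compact(render("Block ID" -> id.toString)))
        // Old form (removed above) nested type-specific fields, roughly:
        // {"Block ID":{"Type":"RDDBlockId","RDD ID":1,"Split Index":2}}
      }
    }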

core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala

Lines changed: 70 additions & 71 deletions
@@ -112,7 +112,6 @@ class JsonProtocolSuite extends FunSuite {
     testBlockId(BroadcastHelperBlockId(BroadcastBlockId(2L), "Spark"))
     testBlockId(TaskResultBlockId(1L))
     testBlockId(StreamBlockId(1, 2L))
-    testBlockId(TempBlockId(UUID.randomUUID()))
   }
 
 
@@ -168,8 +167,8 @@ class JsonProtocolSuite extends FunSuite {
   }
 
   private def testBlockId(blockId: BlockId) {
-    val newBlockId = JsonProtocol.blockIdFromJson(JsonProtocol.blockIdToJson(blockId))
-    blockId == newBlockId
+    val newBlockId = BlockId(blockId.toString)
+    assert(blockId === newBlockId)
   }
 
 

@@ -180,90 +179,90 @@ class JsonProtocolSuite extends FunSuite {
   private def assertEquals(event1: SparkListenerEvent, event2: SparkListenerEvent) {
     (event1, event2) match {
       case (e1: SparkListenerStageSubmitted, e2: SparkListenerStageSubmitted) =>
-        assert(e1.properties == e2.properties)
+        assert(e1.properties === e2.properties)
         assertEquals(e1.stageInfo, e2.stageInfo)
       case (e1: SparkListenerStageCompleted, e2: SparkListenerStageCompleted) =>
         assertEquals(e1.stageInfo, e2.stageInfo)
       case (e1: SparkListenerTaskStart, e2: SparkListenerTaskStart) =>
-        assert(e1.stageId == e2.stageId)
+        assert(e1.stageId === e2.stageId)
         assertEquals(e1.taskInfo, e2.taskInfo)
       case (e1: SparkListenerTaskGettingResult, e2: SparkListenerTaskGettingResult) =>
         assertEquals(e1.taskInfo, e2.taskInfo)
       case (e1: SparkListenerTaskEnd, e2: SparkListenerTaskEnd) =>
-        assert(e1.stageId == e2.stageId)
-        assert(e1.taskType == e2.taskType)
+        assert(e1.stageId === e2.stageId)
+        assert(e1.taskType === e2.taskType)
         assertEquals(e1.reason, e2.reason)
         assertEquals(e1.taskInfo, e2.taskInfo)
         assertEquals(e1.taskMetrics, e2.taskMetrics)
       case (e1: SparkListenerJobStart, e2: SparkListenerJobStart) =>
-        assert(e1.jobId == e2.jobId)
-        assert(e1.properties == e2.properties)
-        assertSeqEquals(e1.stageIds, e2.stageIds, (i1: Int, i2: Int) => assert(i1 == i2))
+        assert(e1.jobId === e2.jobId)
+        assert(e1.properties === e2.properties)
+        assertSeqEquals(e1.stageIds, e2.stageIds, (i1: Int, i2: Int) => assert(i1 === i2))
       case (e1: SparkListenerJobEnd, e2: SparkListenerJobEnd) =>
-        assert(e1.jobId == e2.jobId)
+        assert(e1.jobId === e2.jobId)
         assertEquals(e1.jobResult, e2.jobResult)
       case (e1: SparkListenerEnvironmentUpdate, e2: SparkListenerEnvironmentUpdate) =>
         assertEquals(e1.environmentDetails, e2.environmentDetails)
       case (e1: SparkListenerBlockManagerAdded, e2: SparkListenerBlockManagerAdded) =>
-        assert(e1.maxMem == e2.maxMem)
+        assert(e1.maxMem === e2.maxMem)
         assertEquals(e1.blockManagerId, e2.blockManagerId)
       case (e1: SparkListenerBlockManagerRemoved, e2: SparkListenerBlockManagerRemoved) =>
         assertEquals(e1.blockManagerId, e2.blockManagerId)
       case (e1: SparkListenerUnpersistRDD, e2: SparkListenerUnpersistRDD) =>
-        assert(e1.rddId == e2.rddId)
+        assert(e1.rddId === e2.rddId)
       case (SparkListenerShutdown, SparkListenerShutdown) =>
       case _ => fail("Events don't match in types!")
     }
   }
 
   private def assertEquals(info1: StageInfo, info2: StageInfo) {
-    assert(info1.stageId == info2.stageId)
-    assert(info1.name == info2.name)
-    assert(info1.numTasks == info2.numTasks)
-    assert(info1.submissionTime == info2.submissionTime)
-    assert(info1.completionTime == info2.completionTime)
-    assert(info1.emittedTaskSizeWarning == info2.emittedTaskSizeWarning)
+    assert(info1.stageId === info2.stageId)
+    assert(info1.name === info2.name)
+    assert(info1.numTasks === info2.numTasks)
+    assert(info1.submissionTime === info2.submissionTime)
+    assert(info1.completionTime === info2.completionTime)
+    assert(info1.emittedTaskSizeWarning === info2.emittedTaskSizeWarning)
     assertEquals(info1.rddInfo, info2.rddInfo)
   }
 
   private def assertEquals(info1: RDDInfo, info2: RDDInfo) {
-    assert(info1.id == info2.id)
-    assert(info1.name == info2.name)
-    assert(info1.numPartitions == info2.numPartitions)
-    assert(info1.numCachedPartitions == info2.numCachedPartitions)
-    assert(info1.memSize == info2.memSize)
-    assert(info1.diskSize == info2.diskSize)
+    assert(info1.id === info2.id)
+    assert(info1.name === info2.name)
+    assert(info1.numPartitions === info2.numPartitions)
+    assert(info1.numCachedPartitions === info2.numCachedPartitions)
+    assert(info1.memSize === info2.memSize)
+    assert(info1.diskSize === info2.diskSize)
     assertEquals(info1.storageLevel, info2.storageLevel)
   }
 
   private def assertEquals(level1: StorageLevel, level2: StorageLevel) {
-    assert(level1.useDisk == level2.useDisk)
-    assert(level1.useMemory == level2.useMemory)
-    assert(level1.deserialized == level2.deserialized)
-    assert(level1.replication == level2.replication)
+    assert(level1.useDisk === level2.useDisk)
+    assert(level1.useMemory === level2.useMemory)
+    assert(level1.deserialized === level2.deserialized)
+    assert(level1.replication === level2.replication)
   }
 
   private def assertEquals(info1: TaskInfo, info2: TaskInfo) {
-    assert(info1.taskId == info2.taskId)
-    assert(info1.index == info2.index)
-    assert(info1.launchTime == info2.launchTime)
-    assert(info1.executorId == info2.executorId)
-    assert(info1.host == info2.host)
-    assert(info1.taskLocality == info2.taskLocality)
-    assert(info1.gettingResultTime == info2.gettingResultTime)
-    assert(info1.finishTime == info2.finishTime)
-    assert(info1.failed == info2.failed)
-    assert(info1.serializedSize == info2.serializedSize)
+    assert(info1.taskId === info2.taskId)
+    assert(info1.index === info2.index)
+    assert(info1.launchTime === info2.launchTime)
+    assert(info1.executorId === info2.executorId)
+    assert(info1.host === info2.host)
+    assert(info1.taskLocality === info2.taskLocality)
+    assert(info1.gettingResultTime === info2.gettingResultTime)
+    assert(info1.finishTime === info2.finishTime)
+    assert(info1.failed === info2.failed)
+    assert(info1.serializedSize === info2.serializedSize)
   }
 
   private def assertEquals(metrics1: TaskMetrics, metrics2: TaskMetrics) {
-    assert(metrics1.hostname == metrics2.hostname)
-    assert(metrics1.executorDeserializeTime == metrics2.executorDeserializeTime)
-    assert(metrics1.resultSize == metrics2.resultSize)
-    assert(metrics1.jvmGCTime == metrics2.jvmGCTime)
-    assert(metrics1.resultSerializationTime == metrics2.resultSerializationTime)
-    assert(metrics1.memoryBytesSpilled == metrics2.memoryBytesSpilled)
-    assert(metrics1.diskBytesSpilled == metrics2.diskBytesSpilled)
+    assert(metrics1.hostname === metrics2.hostname)
+    assert(metrics1.executorDeserializeTime === metrics2.executorDeserializeTime)
+    assert(metrics1.resultSize === metrics2.resultSize)
+    assert(metrics1.jvmGCTime === metrics2.jvmGCTime)
+    assert(metrics1.resultSerializationTime === metrics2.resultSerializationTime)
+    assert(metrics1.memoryBytesSpilled === metrics2.memoryBytesSpilled)
+    assert(metrics1.diskBytesSpilled === metrics2.diskBytesSpilled)
     assertOptionEquals(
       metrics1.shuffleReadMetrics, metrics2.shuffleReadMetrics, assertShuffleReadEquals)
     assertOptionEquals(
@@ -272,31 +271,31 @@ class JsonProtocolSuite extends FunSuite {
   }
 
   private def assertEquals(metrics1: ShuffleReadMetrics, metrics2: ShuffleReadMetrics) {
-    assert(metrics1.shuffleFinishTime == metrics2.shuffleFinishTime)
-    assert(metrics1.totalBlocksFetched == metrics2.totalBlocksFetched)
-    assert(metrics1.remoteBlocksFetched == metrics2.remoteBlocksFetched)
-    assert(metrics1.localBlocksFetched == metrics2.localBlocksFetched)
-    assert(metrics1.fetchWaitTime == metrics2.fetchWaitTime)
-    assert(metrics1.remoteBytesRead == metrics2.remoteBytesRead)
+    assert(metrics1.shuffleFinishTime === metrics2.shuffleFinishTime)
+    assert(metrics1.totalBlocksFetched === metrics2.totalBlocksFetched)
+    assert(metrics1.remoteBlocksFetched === metrics2.remoteBlocksFetched)
+    assert(metrics1.localBlocksFetched === metrics2.localBlocksFetched)
+    assert(metrics1.fetchWaitTime === metrics2.fetchWaitTime)
+    assert(metrics1.remoteBytesRead === metrics2.remoteBytesRead)
   }
 
   private def assertEquals(metrics1: ShuffleWriteMetrics, metrics2: ShuffleWriteMetrics) {
-    assert(metrics1.shuffleBytesWritten == metrics2.shuffleBytesWritten)
-    assert(metrics1.shuffleWriteTime == metrics2.shuffleWriteTime)
+    assert(metrics1.shuffleBytesWritten === metrics2.shuffleBytesWritten)
+    assert(metrics1.shuffleWriteTime === metrics2.shuffleWriteTime)
   }
 
   private def assertEquals(bm1: BlockManagerId, bm2: BlockManagerId) {
-    assert(bm1.executorId == bm2.executorId)
-    assert(bm1.host == bm2.host)
-    assert(bm1.port == bm2.port)
-    assert(bm1.nettyPort == bm2.nettyPort)
+    assert(bm1.executorId === bm2.executorId)
+    assert(bm1.host === bm2.host)
+    assert(bm1.port === bm2.port)
+    assert(bm1.nettyPort === bm2.nettyPort)
   }
 
   private def assertEquals(result1: JobResult, result2: JobResult) {
     (result1, result2) match {
       case (JobSucceeded, JobSucceeded) =>
       case (r1: JobFailed, r2: JobFailed) =>
-        assert(r1.failedStageId == r2.failedStageId)
+        assert(r1.failedStageId === r2.failedStageId)
         assertEquals(r1.exception, r2.exception)
       case _ => fail("Job results don't match in types!")
     }
@@ -307,13 +306,13 @@ class JsonProtocolSuite extends FunSuite {
       case (Success, Success) =>
       case (Resubmitted, Resubmitted) =>
       case (r1: FetchFailed, r2: FetchFailed) =>
-        assert(r1.shuffleId == r2.shuffleId)
-        assert(r1.mapId == r2.mapId)
-        assert(r1.reduceId == r2.reduceId)
+        assert(r1.shuffleId === r2.shuffleId)
+        assert(r1.mapId === r2.mapId)
+        assert(r1.reduceId === r2.reduceId)
         assertEquals(r1.bmAddress, r2.bmAddress)
       case (r1: ExceptionFailure, r2: ExceptionFailure) =>
-        assert(r1.className == r2.className)
-        assert(r1.description == r2.description)
+        assert(r1.className === r2.className)
+        assert(r1.description === r2.description)
         assertSeqEquals(r1.stackTrace, r2.stackTrace, assertStackTraceElementEquals)
         assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals)
       case (TaskResultLost, TaskResultLost) =>
@@ -329,13 +328,13 @@ class JsonProtocolSuite extends FunSuite {
       details2: Map[String, Seq[(String, String)]]) {
     details1.zip(details2).foreach {
       case ((key1, values1: Seq[(String, String)]), (key2, values2: Seq[(String, String)])) =>
-        assert(key1 == key2)
-        values1.zip(values2).foreach { case (v1, v2) => assert(v1 == v2) }
+        assert(key1 === key2)
+        values1.zip(values2).foreach { case (v1, v2) => assert(v1 === v2) }
     }
   }
 
   private def assertEquals(exception1: Exception, exception2: Exception) {
-    assert(exception1.getMessage == exception2.getMessage)
+    assert(exception1.getMessage === exception2.getMessage)
     assertSeqEquals(
       exception1.getStackTrace,
       exception2.getStackTrace,
@@ -344,11 +343,11 @@ class JsonProtocolSuite extends FunSuite {
 
   private def assertJsonStringEquals(json1: String, json2: String) {
     val formatJsonString = (json: String) => json.replaceAll("[\\s|]", "")
-    formatJsonString(json1) == formatJsonString(json2)
+    formatJsonString(json1) === formatJsonString(json2)
   }
 
   private def assertSeqEquals[T](seq1: Seq[T], seq2: Seq[T], assertEquals: (T, T) => Unit) {
-    assert(seq1.length == seq2.length)
+    assert(seq1.length === seq2.length)
     seq1.zip(seq2).foreach { case (t1, t2) =>
       assertEquals(t1, t2)
     }
@@ -389,11 +388,11 @@ class JsonProtocolSuite extends FunSuite {
   }
 
   private def assertBlockEquals(b1: (BlockId, BlockStatus), b2: (BlockId, BlockStatus)) {
-    assert(b1 == b2)
+    assert(b1 === b2)
   }
 
   private def assertStackTraceElementEquals(ste1: StackTraceElement, ste2: StackTraceElement) {
-    assert(ste1 == ste2)
+    assert(ste1 === ste2)
   }
 
 