diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index e03198b0169f4..e2fc53890916a 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -1036,40 +1036,49 @@ final class ShuffleBlockFetcherIterator( address: BlockManagerId, blockId: BlockId): String = { logInfo("Start corruption diagnosis.") - val startTimeNs = System.nanoTime() - assert(blockId.isInstanceOf[ShuffleBlockId], s"Expected ShuffleBlockId, but got $blockId") - val shuffleBlock = blockId.asInstanceOf[ShuffleBlockId] - val buffer = new Array[Byte](ShuffleChecksumHelper.CHECKSUM_CALCULATION_BUFFER) - // consume the remaining data to calculate the checksum - var cause: Cause = null - try { - while (checkedIn.read(buffer) != -1) {} - val checksum = checkedIn.getChecksum.getValue - cause = shuffleClient.diagnoseCorruption(address.host, address.port, address.executorId, - shuffleBlock.shuffleId, shuffleBlock.mapId, shuffleBlock.reduceId, checksum, - checksumAlgorithm) - } catch { - case e: Exception => - logWarning("Unable to diagnose the corruption cause of the corrupted block", e) - cause = Cause.UNKNOWN_ISSUE - } - val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) - val diagnosisResponse = cause match { - case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM => - s"Block $blockId is corrupted but corruption diagnosis failed due to " + - s"unsupported checksum algorithm: $checksumAlgorithm" + blockId match { + case shuffleBlock: ShuffleBlockId => + val startTimeNs = System.nanoTime() + val buffer = new Array[Byte](ShuffleChecksumHelper.CHECKSUM_CALCULATION_BUFFER) + // consume the remaining data to calculate the checksum + var cause: Cause = null + try { + while (checkedIn.read(buffer) != -1) {} + val checksum = checkedIn.getChecksum.getValue + cause = shuffleClient.diagnoseCorruption(address.host, address.port, address.executorId, + shuffleBlock.shuffleId, shuffleBlock.mapId, shuffleBlock.reduceId, checksum, + checksumAlgorithm) + } catch { + case e: Exception => + logWarning("Unable to diagnose the corruption cause of the corrupted block", e) + cause = Cause.UNKNOWN_ISSUE + } + val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs) + val diagnosisResponse = cause match { + case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM => + s"Block $blockId is corrupted but corruption diagnosis failed due to " + + s"unsupported checksum algorithm: $checksumAlgorithm" - case Cause.CHECKSUM_VERIFY_PASS => - s"Block $blockId is corrupted but checksum verification passed" + case Cause.CHECKSUM_VERIFY_PASS => + s"Block $blockId is corrupted but checksum verification passed" - case Cause.UNKNOWN_ISSUE => - s"Block $blockId is corrupted but the cause is unknown" + case Cause.UNKNOWN_ISSUE => + s"Block $blockId is corrupted but the cause is unknown" - case otherCause => - s"Block $blockId is corrupted due to $otherCause" + case otherCause => + s"Block $blockId is corrupted due to $otherCause" + } + logInfo(s"Finished corruption diagnosis in $duration ms. $diagnosisResponse") + diagnosisResponse + case shuffleBlockChunk: ShuffleBlockChunkId => + // TODO SPARK-36284 Add shuffle checksum support for push-based shuffle + val diagnosisResponse = s"BlockChunk $shuffleBlockChunk is corrupted but corruption " + + s"diagnosis is skipped due to lack of shuffle checksum support for push-based shuffle." + logWarning(diagnosisResponse) + diagnosisResponse + case unexpected: BlockId => + throw new IllegalArgumentException(s"Unexpected type of BlockId, $unexpected") } - logInfo(s"Finished corruption diagnosis in $duration ms. $diagnosisResponse") - diagnosisResponse } def toCompletionIterator: Iterator[(BlockId, InputStream)] = {