
Commit e2b8c4a

Modify to propagate error using ConnectionManager
1 parent 30b8d36 commit e2b8c4a
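
In short: Message gains a hasError flag, MessageChunkHeader carries it on the wire as one extra byte (HEADER_SIZE grows from 44 to 45), BufferMessage stamps it onto every outgoing chunk header, BlockManagerWorker replies with an error-flagged message instead of no reply when handling a buffer message fails, and BlockFetcherIterator turns a flagged reply into failed FetchResults for every requested block.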

6 files changed: 46 additions & 19 deletions


core/src/main/scala/org/apache/spark/network/BufferMessage.scala

Lines changed: 3 additions & 3 deletions

@@ -48,7 +48,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
     val security = if (isSecurityNeg) 1 else 0
     if (size == 0 && !gotChunkForSendingOnce) {
       val newChunk = new MessageChunk(
-        new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null)
+        new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null)
       gotChunkForSendingOnce = true
       return Some(newChunk)
     }
@@ -66,7 +66,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
       }
       buffer.position(buffer.position + newBuffer.remaining)
       val newChunk = new MessageChunk(new MessageChunkHeader(
-        typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
+        typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer)
       gotChunkForSendingOnce = true
       return Some(newChunk)
     }
@@ -88,7 +88,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
       val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
       buffer.position(buffer.position + newBuffer.remaining)
       val newChunk = new MessageChunk(new MessageChunkHeader(
-        typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
+        typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer)
       return Some(newChunk)
     }
     None
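
Note on the change above: the receiver rebuilds a Message from chunk headers alone (Message.create below copies header.hasError onto the reconstructed message), so the flag has to travel in every chunk's MessageChunkHeader. That is why all three header construction sites gain the extra hasError argument.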

core/src/main/scala/org/apache/spark/network/ConnectionManager.scala

Lines changed: 0 additions & 1 deletion

@@ -674,7 +674,6 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
           ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
         }
       }
-
      sendMessage(connectionManagerId, ackMessage.getOrElse {
        Message.createBufferMessage(bufferMessage.id)
      })

core/src/main/scala/org/apache/spark/network/Message.scala

Lines changed: 2 additions & 0 deletions

@@ -28,6 +28,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
   var startTime = -1L
   var finishTime = -1L
   var isSecurityNeg = false
+  var hasError = false

   def size: Int

@@ -87,6 +88,7 @@ private[spark] object Message {
       case BUFFER_MESSAGE => new BufferMessage(header.id,
         ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
     }
+    newMessage.hasError = header.hasError
     newMessage.senderAddress = header.address
     newMessage
   }

core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala

Lines changed: 17 additions & 2 deletions

@@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader(
     val totalSize: Int,
     val chunkSize: Int,
     val other: Int,
+    val hasError: Boolean,
     val securityNeg: Int,
     val address: InetSocketAddress) {
   lazy val buffer = {
@@ -41,6 +42,13 @@ private[spark] class MessageChunkHeader(
       putInt(totalSize).
       putInt(chunkSize).
       putInt(other).
+      put{
+        if (hasError) {
+          1.asInstanceOf[Byte]
+        } else {
+          0.asInstanceOf[Byte]
+        }
+      }.
       putInt(securityNeg).
       putInt(ip.size).
       put(ip).
@@ -56,7 +64,7 @@


 private[spark] object MessageChunkHeader {
-  val HEADER_SIZE = 44
+  val HEADER_SIZE = 45

   def create(buffer: ByteBuffer): MessageChunkHeader = {
     if (buffer.remaining != HEADER_SIZE) {
@@ -67,13 +75,20 @@ private[spark] object MessageChunkHeader {
     val totalSize = buffer.getInt()
     val chunkSize = buffer.getInt()
     val other = buffer.getInt()
+    val hasError = {
+      if (buffer.get() == 0) {
+        false
+      } else {
+        true
+      }
+    }
     val securityNeg = buffer.getInt()
     val ipSize = buffer.getInt()
     val ipBytes = new Array[Byte](ipSize)
     buffer.get(ipBytes)
     val ip = InetAddress.getByAddress(ipBytes)
     val port = buffer.getInt()
-    new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg,
+    new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg,
       new InetSocketAddress(ip, port))
   }
 }
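
To make the wire format concrete, here is a minimal standalone round-trip sketch of the same one-byte Boolean encoding. The names (ErrorFlagRoundTrip, encode, decode) are illustrative, not Spark's; the encoding itself mirrors the diff above, where the single extra byte is what moves HEADER_SIZE from 44 to 45 and any non-zero value decodes as true.

import java.nio.ByteBuffer

// Standalone sketch of the one-byte Boolean encoding used by
// MessageChunkHeader above; all names here are illustrative.
object ErrorFlagRoundTrip {
  def encode(id: Int, hasError: Boolean): ByteBuffer = {
    val buf = ByteBuffer.allocate(5)
    buf.putInt(id)                                  // 4 bytes
    buf.put(if (hasError) 1.toByte else 0.toByte)   // 1 byte for the flag
    buf.flip()
    buf
  }

  def decode(buf: ByteBuffer): (Int, Boolean) = {
    val id = buf.getInt()
    val hasError = buf.get() != 0                   // non-zero means error
    (id, hasError)
  }

  def main(args: Array[String]): Unit = {
    val expected = (42, true)
    assert(decode(encode(42, hasError = true)) == expected)
  }
}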

core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala

Lines changed: 18 additions & 11 deletions

@@ -122,18 +122,25 @@ object BlockFetcherIterator {
       future.onSuccess {
         case Some(message) => {
           val bufferMessage = message.asInstanceOf[BufferMessage]
-          val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
-          for (blockMessage <- blockMessageArray) {
-            if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
-              throw new SparkException(
-                "Unexpected message " + blockMessage.getType + " received from " + cmId)
+          if (bufferMessage.hasError) {
+            logError("Could not get block(s) from " + cmId)
+            for ((blockId, size) <- req.blocks) {
+              results.put(new FetchResult(blockId, -1, null))
+            }
+          } else {
+            val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
+            for (blockMessage <- blockMessageArray) {
+              if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
+                throw new SparkException(
+                  "Unexpected message " + blockMessage.getType + " received from " + cmId)
+              }
+              val blockId = blockMessage.getId
+              val networkSize = blockMessage.getData.limit()
+              results.put(new FetchResult(blockId, sizeMap(blockId),
+                () => dataDeserialize(blockId, blockMessage.getData, serializer)))
+              _remoteBytesRead += networkSize
+              logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
             }
-            val blockId = blockMessage.getId
-            val networkSize = blockMessage.getData.limit()
-            results.put(new FetchResult(blockId, sizeMap(blockId),
-              () => dataDeserialize(blockId, blockMessage.getData, serializer)))
-            _remoteBytesRead += networkSize
-            logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
           }
         }
         case None => {
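
For context on the size of -1 used above: it acts as a failure sentinel inside FetchResult, letting the consumer distinguish a failed fetch from real data. A simplified sketch of that convention, with types reduced (the real FetchResult uses Spark's BlockId and lives inside BlockFetcherIterator):

// Simplified sketch of the sentinel convention: size == -1 (with a null
// deserializer) marks a block whose fetch failed.
object FetchResultSketch {
  case class FetchResult(blockId: String, size: Long, deserialize: () => Iterator[Any]) {
    def failed: Boolean = size == -1
  }

  // Consumers branch on failure instead of invoking the null deserializer.
  def consume(result: FetchResult): Unit =
    if (result.failed) println("fetch of " + result.blockId + " failed")
    else result.deserialize().foreach(println)
}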

core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala

Lines changed: 6 additions & 2 deletions

@@ -44,8 +44,12 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
         val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
         Some(new BlockMessageArray(responseMessages).toBufferMessage)
       } catch {
-        case e: Exception => logError("Exception handling buffer message", e)
-          None
+        case e: Exception => {
+          logError("Exception handling buffer message", e)
+          val errorMessage = Message.createBufferMessage(msg.id)
+          errorMessage.hasError = true
+          Some(errorMessage)
+        }
       }
     }
     case otherMessage: Any => {
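
Reduced to a standalone sketch (hypothetical names), the fix above changes the worker's answer on failure from None, i.e. silence, to Some(reply) with the error flag set, so the requesting side gets a definite failure signal it can act on:

// Standalone sketch of the reply-with-error pattern; names are illustrative.
object ErrorReplySketch {
  case class Reply(id: Int, var hasError: Boolean = false)

  def handleRequest(requestId: Int)(work: => Option[Reply]): Option[Reply] =
    try work
    catch {
      case e: Exception =>
        println("Exception handling buffer message: " + e.getMessage)
        val errorReply = Reply(requestId)
        errorReply.hasError = true  // the flag rides back in the ack's headers
        Some(errorReply)
    }
}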
