Skip to content

Commit 7399c6b

Browse files
committed
Merge remote-tracking branch 'origin/pr/1490' into connection-manager-fixes
Conflicts: core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
2 parents 78f2af5 + ee91bb7 commit 7399c6b

File tree

9 files changed

+297
-42
lines changed

9 files changed

+297
-42
lines changed

core/src/main/scala/org/apache/spark/network/BufferMessage.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
4848
val security = if (isSecurityNeg) 1 else 0
4949
if (size == 0 && !gotChunkForSendingOnce) {
5050
val newChunk = new MessageChunk(
51-
new MessageChunkHeader(typ, id, 0, 0, ackId, security, senderAddress), null)
51+
new MessageChunkHeader(typ, id, 0, 0, ackId, hasError, security, senderAddress), null)
5252
gotChunkForSendingOnce = true
5353
return Some(newChunk)
5454
}
@@ -66,7 +66,8 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
6666
}
6767
buffer.position(buffer.position + newBuffer.remaining)
6868
val newChunk = new MessageChunk(new MessageChunkHeader(
69-
typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
69+
typ, id, size, newBuffer.remaining, ackId,
70+
hasError, security, senderAddress), newBuffer)
7071
gotChunkForSendingOnce = true
7172
return Some(newChunk)
7273
}
@@ -88,7 +89,7 @@ class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId:
8889
val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
8990
buffer.position(buffer.position + newBuffer.remaining)
9091
val newChunk = new MessageChunk(new MessageChunkHeader(
91-
typ, id, size, newBuffer.remaining, ackId, security, senderAddress), newBuffer)
92+
typ, id, size, newBuffer.remaining, ackId, hasError, security, senderAddress), newBuffer)
9293
return Some(newChunk)
9394
}
9495
None

core/src/main/scala/org/apache/spark/network/ConnectionManager.scala

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -660,27 +660,37 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
660660
sentMessageStatus.markDone()
661661
}
662662
} else {
663-
val ackMessage = if (onReceiveCallback != null) {
664-
logDebug("Calling back")
665-
onReceiveCallback(bufferMessage, connectionManagerId)
666-
} else {
667-
logDebug("Not calling back as callback is null")
668-
None
669-
}
663+
var ackMessage : Option[Message] = None
664+
try {
665+
ackMessage = if (onReceiveCallback != null) {
666+
logDebug("Calling back")
667+
onReceiveCallback(bufferMessage, connectionManagerId)
668+
} else {
669+
logDebug("Not calling back as callback is null")
670+
None
671+
}
670672

671-
if (ackMessage.isDefined) {
672-
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
673-
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
674-
+ ackMessage.get.getClass)
675-
} else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
676-
logDebug("Response to " + bufferMessage + " does not have ack id set")
677-
ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
673+
if (ackMessage.isDefined) {
674+
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
675+
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
676+
+ ackMessage.get.getClass)
677+
} else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
678+
logDebug("Response to " + bufferMessage + " does not have ack id set")
679+
ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
680+
}
678681
}
682+
} catch {
683+
case e: Exception => {
684+
logError(s"Exception was thrown during processing message", e)
685+
val m = Message.createBufferMessage(bufferMessage.id)
686+
m.hasError = true
687+
ackMessage = Some(m)
688+
}
689+
} finally {
690+
sendMessage(connectionManagerId, ackMessage.getOrElse {
691+
Message.createBufferMessage(bufferMessage.id)
692+
})
679693
}
680-
681-
sendMessage(connectionManagerId, ackMessage.getOrElse {
682-
Message.createBufferMessage(bufferMessage.id)
683-
})
684694
}
685695
}
686696
case _ => throw new Exception("Unknown type message received")

core/src/main/scala/org/apache/spark/network/Message.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
2828
var startTime = -1L
2929
var finishTime = -1L
3030
var isSecurityNeg = false
31+
var hasError = false
3132

3233
def size: Int
3334

@@ -87,6 +88,7 @@ private[spark] object Message {
8788
case BUFFER_MESSAGE => new BufferMessage(header.id,
8889
ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
8990
}
91+
newMessage.hasError = header.hasError
9092
newMessage.senderAddress = header.address
9193
newMessage
9294
}

core/src/main/scala/org/apache/spark/network/MessageChunkHeader.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ private[spark] class MessageChunkHeader(
2727
val totalSize: Int,
2828
val chunkSize: Int,
2929
val other: Int,
30+
val hasError: Boolean,
3031
val securityNeg: Int,
3132
val address: InetSocketAddress) {
3233
lazy val buffer = {
@@ -41,6 +42,7 @@ private[spark] class MessageChunkHeader(
4142
putInt(totalSize).
4243
putInt(chunkSize).
4344
putInt(other).
45+
put(if (hasError) 1.asInstanceOf[Byte] else 0.asInstanceOf[Byte]).
4446
putInt(securityNeg).
4547
putInt(ip.size).
4648
put(ip).
@@ -56,7 +58,7 @@ private[spark] class MessageChunkHeader(
5658

5759

5860
private[spark] object MessageChunkHeader {
59-
val HEADER_SIZE = 44
61+
val HEADER_SIZE = 45
6062

6163
def create(buffer: ByteBuffer): MessageChunkHeader = {
6264
if (buffer.remaining != HEADER_SIZE) {
@@ -67,13 +69,14 @@ private[spark] object MessageChunkHeader {
6769
val totalSize = buffer.getInt()
6870
val chunkSize = buffer.getInt()
6971
val other = buffer.getInt()
72+
val hasError = buffer.get() != 0
7073
val securityNeg = buffer.getInt()
7174
val ipSize = buffer.getInt()
7275
val ipBytes = new Array[Byte](ipSize)
7376
buffer.get(ipBytes)
7477
val ip = InetAddress.getByAddress(ipBytes)
7578
val port = buffer.getInt()
76-
new MessageChunkHeader(typ, id, totalSize, chunkSize, other, securityNeg,
79+
new MessageChunkHeader(typ, id, totalSize, chunkSize, other, hasError, securityNeg,
7780
new InetSocketAddress(ip, port))
7881
}
7982
}

core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -121,18 +121,25 @@ object BlockFetcherIterator {
121121
future.onSuccess {
122122
case Some(message) => {
123123
val bufferMessage = message.asInstanceOf[BufferMessage]
124-
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
125-
for (blockMessage <- blockMessageArray) {
126-
if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
127-
throw new SparkException(
128-
"Unexpected message " + blockMessage.getType + " received from " + cmId)
124+
if (bufferMessage.hasError) {
125+
logError("Could not get block(s) from " + cmId)
126+
for ((blockId, size) <- req.blocks) {
127+
results.put(new FetchResult(blockId, -1, null))
128+
}
129+
} else {
130+
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
131+
for (blockMessage <- blockMessageArray) {
132+
if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
133+
throw new SparkException(
134+
"Unexpected message " + blockMessage.getType + " received from " + cmId)
135+
}
136+
val blockId = blockMessage.getId
137+
val networkSize = blockMessage.getData.limit()
138+
results.put(new FetchResult(blockId, sizeMap(blockId),
139+
() => dataDeserialize(blockId, blockMessage.getData, serializer)))
140+
_remoteBytesRead += networkSize
141+
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
129142
}
130-
val blockId = blockMessage.getId
131-
val networkSize = blockMessage.getData.limit()
132-
results.put(new FetchResult(blockId, sizeMap(blockId),
133-
() => dataDeserialize(blockId, blockMessage.getData, serializer)))
134-
_remoteBytesRead += networkSize
135-
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
136143
}
137144
}
138145
case None => {

core/src/main/scala/org/apache/spark/storage/BlockManagerWorker.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,19 @@ private[spark] class BlockManagerWorker(val blockManager: BlockManager) extends
4444
val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
4545
Some(new BlockMessageArray(responseMessages).toBufferMessage)
4646
} catch {
47-
case e: Exception => logError("Exception handling buffer message", e)
48-
None
47+
case e: Exception => {
48+
logError("Exception handling buffer message", e)
49+
val errorMessage = Message.createBufferMessage(msg.id)
50+
errorMessage.hasError = true
51+
Some(errorMessage)
52+
}
4953
}
5054
}
5155
case otherMessage: Any => {
5256
logError("Unknown type message received: " + otherMessage)
53-
None
57+
val errorMessage = Message.createBufferMessage(msg.id)
58+
errorMessage.hasError = true
59+
Some(errorMessage)
5460
}
5561
}
5662
}

core/src/test/scala/org/apache/spark/network/ConnectionManagerSuite.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,31 @@ class ConnectionManagerSuite extends FunSuite {
223223
managerServer.stop()
224224
}
225225

226+
test("Ack error message") {
227+
val conf = new SparkConf
228+
conf.set("spark.authenticate", "false")
229+
val securityManager = new SecurityManager(conf)
230+
val manager = new ConnectionManager(0, conf, securityManager)
231+
val managerServer = new ConnectionManager(0, conf, securityManager)
232+
managerServer.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
233+
throw new Exception
234+
})
226235

236+
val size = 10 * 1024 * 1024
237+
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
238+
buffer.flip
239+
val bufferMessage = Message.createBufferMessage(buffer)
240+
241+
val future = manager.sendMessageReliably(managerServer.id, bufferMessage)
242+
243+
val message = Await.result(future, 1 second)
244+
assert(message.isDefined)
245+
assert(message.get.hasError)
246+
247+
manager.stop()
248+
managerServer.stop()
249+
250+
}
227251

228252
}
229253

core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala

Lines changed: 98 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,21 @@
1717

1818
package org.apache.spark.storage
1919

20+
import java.nio.ByteBuffer
21+
22+
import scala.collection.mutable.ArrayBuffer
23+
import scala.concurrent.future
24+
import scala.concurrent.ExecutionContext.Implicits.global
25+
2026
import org.scalatest.{FunSuite, Matchers}
21-
import org.scalatest.PrivateMethodTester._
2227

2328
import org.mockito.Mockito._
2429
import org.mockito.Matchers.{any, eq => meq}
2530
import org.mockito.stubbing.Answer
2631
import org.mockito.invocation.InvocationOnMock
2732

28-
import org.apache.spark._
2933
import org.apache.spark.storage.BlockFetcherIterator._
30-
import org.apache.spark.network.{ConnectionManager, ConnectionManagerId,
31-
Message}
34+
import org.apache.spark.network.{ConnectionManager, Message}
3235

3336
class BlockFetcherIteratorSuite extends FunSuite with Matchers {
3437

@@ -137,4 +140,95 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
137140
assert(iterator.next._2.isDefined, "All elements should be defined but 5th element is not actually defined")
138141
}
139142

143+
test("block fetch from remote fails using BasicBlockFetcherIterator") {
144+
val blockManager = mock(classOf[BlockManager])
145+
val connManager = mock(classOf[ConnectionManager])
146+
when(blockManager.connectionManager).thenReturn(connManager)
147+
148+
val f = future {
149+
val message = Message.createBufferMessage(0)
150+
message.hasError = true
151+
val someMessage = Some(message)
152+
someMessage
153+
}
154+
when(connManager.sendMessageReliably(any(),
155+
any())).thenReturn(f)
156+
when(blockManager.futureExecContext).thenReturn(global)
157+
158+
when(blockManager.blockManagerId).thenReturn(
159+
BlockManagerId("test-client", "test-client", 1, 0))
160+
when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024)
161+
162+
val blId1 = ShuffleBlockId(0,0,0)
163+
val blId2 = ShuffleBlockId(0,1,0)
164+
val bmId = BlockManagerId("test-server", "test-server",1 , 0)
165+
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
166+
(bmId, Seq((blId1, 1L), (blId2, 1L)))
167+
)
168+
169+
val iterator = new BasicBlockFetcherIterator(blockManager,
170+
blocksByAddress, null)
171+
172+
iterator.initialize()
173+
iterator.foreach{
174+
case (_, r) => {
175+
(!r.isDefined) should be(true)
176+
}
177+
}
178+
}
179+
180+
test("block fetch from remote succeed using BasicBlockFetcherIterator") {
181+
val blockManager = mock(classOf[BlockManager])
182+
val connManager = mock(classOf[ConnectionManager])
183+
when(blockManager.connectionManager).thenReturn(connManager)
184+
185+
val blId1 = ShuffleBlockId(0,0,0)
186+
val blId2 = ShuffleBlockId(0,1,0)
187+
val buf1 = ByteBuffer.allocate(4)
188+
val buf2 = ByteBuffer.allocate(4)
189+
buf1.putInt(1)
190+
buf1.flip()
191+
buf2.putInt(1)
192+
buf2.flip()
193+
val blockMessage1 = BlockMessage.fromGotBlock(GotBlock(blId1, buf1))
194+
val blockMessage2 = BlockMessage.fromGotBlock(GotBlock(blId2, buf2))
195+
val blockMessageArray = new BlockMessageArray(
196+
Seq(blockMessage1, blockMessage2))
197+
198+
val bufferMessage = blockMessageArray.toBufferMessage
199+
val buffer = ByteBuffer.allocate(bufferMessage.size)
200+
val arrayBuffer = new ArrayBuffer[ByteBuffer]
201+
bufferMessage.buffers.foreach{ b =>
202+
buffer.put(b)
203+
}
204+
buffer.flip()
205+
arrayBuffer += buffer
206+
207+
val someMessage = Some(Message.createBufferMessage(arrayBuffer))
208+
209+
val f = future {
210+
someMessage
211+
}
212+
when(connManager.sendMessageReliably(any(),
213+
any())).thenReturn(f)
214+
when(blockManager.futureExecContext).thenReturn(global)
215+
216+
when(blockManager.blockManagerId).thenReturn(
217+
BlockManagerId("test-client", "test-client", 1, 0))
218+
when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024)
219+
220+
val bmId = BlockManagerId("test-server", "test-server",1 , 0)
221+
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
222+
(bmId, Seq((blId1, 1L), (blId2, 1L)))
223+
)
224+
225+
val iterator = new BasicBlockFetcherIterator(blockManager,
226+
blocksByAddress, null)
227+
iterator.initialize()
228+
iterator.foreach{
229+
case (_, r) => {
230+
(r.isDefined) should be(true)
231+
}
232+
}
233+
}
140234
}

0 commit comments

Comments
 (0)