Commit 4c6d0ee

Pass callbacks cleanly.
1 parent 603dce7 commit 4c6d0ee

6 files changed: +152 -73 lines changed
core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty.client
+
+import java.util.EventListener
+
+
+trait BlockClientListener extends EventListener {
+
+  def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit
+
+  def onFetchFailure(blockId: String, errorMsg: String): Unit
+
+}
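
For context, a minimal sketch of how a caller might implement this trait. The
class name and println bodies are hypothetical, purely for illustration:

    // Hypothetical listener implementation (not part of this commit).
    class PrintingListener extends BlockClientListener {
      override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = {
        // Retain the buffer if it must outlive this callback; release when done.
        data.retain()
        println(s"Fetched $blockId: ${data.byteBuffer().remaining()} bytes")
        data.release()
      }
      override def onFetchFailure(blockId: String, errorMsg: String): Unit = {
        println(s"Fetch of $blockId failed: $errorMsg")
      }
    }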

core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala

Lines changed: 14 additions & 17 deletions

@@ -37,14 +37,14 @@ import org.apache.spark.Logging
  *
  * See [[org.apache.spark.network.netty.server.BlockServer]] for client/server protocol.
  *
- * Concurrency: [[BlockFetchingClient]] is not thread safe and should not be shared.
+ * Concurrency: thread safe and can be called from multiple threads.
  */
 @throws[TimeoutException]
 private[spark]
 class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int)
   extends Logging {

-  val handler = new BlockFetchingClientHandler
+  private val handler = new BlockFetchingClientHandler

   /** Netty Bootstrap for creating the TCP connection. */
   private val bootstrap: Bootstrap = {
@@ -84,17 +84,9 @@ class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String,
    * rate of fetching; otherwise we could run out of memory.
    *
    * @param blockIds sequence of block ids to fetch.
-   * @param blockFetchSuccessCallback callback function when a block is successfully fetched.
-   *                                  First argument is the block id, and second argument is the
-   *                                  raw data in a ByteBuffer.
-   * @param blockFetchFailureCallback callback function when we failed to fetch any of the blocks.
-   *                                  First argument is the block id, and second argument is the
-   *                                  error message.
+   * @param listener callback to fire on fetch success / failure.
    */
-  def fetchBlocks(
-      blockIds: Seq[String],
-      blockFetchSuccessCallback: (String, ReferenceCountedBuffer) => Unit,
-      blockFetchFailureCallback: (String, String) => Unit): Unit = {
+  def fetchBlocks(blockIds: Seq[String], listener: BlockClientListener): Unit = {
     // It's best to limit the number of "write" calls since it needs to traverse the whole pipeline.
     // It's also best to limit the number of "flush" calls since it requires system calls.
     // Let's concatenate the string and then call writeAndFlush once.
@@ -106,9 +98,9 @@ class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String,
       s"Sending request $blockIds to $hostname:$port"
     }

-    // TODO: This is not the most elegant way to handle this ...
-    handler.blockFetchSuccessCallback = blockFetchSuccessCallback
-    handler.blockFetchFailureCallback = blockFetchFailureCallback
+    blockIds.foreach { blockId =>
+      handler.addRequest(blockId, listener)
+    }

     val writeFuture = cf.channel().writeAndFlush(blockIds.mkString("\n") + "\n")
     writeFuture.addListener(new ChannelFutureListener {
@@ -120,8 +112,13 @@ class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String,
         }
       } else {
         // Fail all blocks.
-        logError(s"Failed to send request $blockIds to $hostname:$port", future.cause)
-        blockIds.foreach(blockFetchFailureCallback(_, future.cause.getMessage))
+        val errorMsg =
+          s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}"
+        logError(errorMsg, future.cause)
+        blockIds.foreach { blockId =>
+          listener.onFetchFailure(blockId, errorMsg)
+          handler.removeRequest(blockId)
+        }
       }
     }
   })
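
Usage note: the two positional callback closures are replaced by a single
listener object. A minimal call-site sketch under the new signature (the
block ids shown are made up):

    // Hypothetical call-site sketch for the new fetchBlocks signature.
    client.fetchBlocks(
      Seq("shuffle_0_0_0", "shuffle_0_1_0"),
      new BlockClientListener {
        override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit =
          println(s"got $blockId")
        override def onFetchFailure(blockId: String, errorMsg: String): Unit =
          println(s"failed $blockId: $errorMsg")
      })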

core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala

Lines changed: 45 additions & 5 deletions

@@ -26,15 +26,39 @@ import org.apache.spark.Logging
 /**
  * Handler that processes server responses. It uses the protocol documented in
  * [[org.apache.spark.network.netty.server.BlockServer]].
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
  */
 private[client]
 class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] with Logging {

-  var blockFetchSuccessCallback: (String, ReferenceCountedBuffer) => Unit = _
-  var blockFetchFailureCallback: (String, String) => Unit = _
+  /** Tracks the list of outstanding requests and their listeners on success/failure. */
+  private val outstandingRequests = java.util.Collections.synchronizedMap {
+    new java.util.HashMap[String, BlockClientListener]
+  }
+
+  def addRequest(blockId: String, listener: BlockClientListener): Unit = {
+    outstandingRequests.put(blockId, listener)
+  }
+
+  def removeRequest(blockId: String): Unit = {
+    outstandingRequests.remove(blockId)
+  }

   override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
-    logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause)
+    val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}"
+    logError(errorMsg, cause)
+
+    // Fire the failure callback for all outstanding blocks
+    outstandingRequests.synchronized {
+      val iter = outstandingRequests.entrySet().iterator()
+      while (iter.hasNext) {
+        val entry = iter.next()
+        entry.getValue.onFetchFailure(entry.getKey, errorMsg)
+      }
+      outstandingRequests.clear()
+    }

     ctx.close()
   }

@@ -54,10 +78,26 @@ class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] wi
       in.readBytes(errorMessageBytes)
       val errorMsg = new String(errorMessageBytes)
       logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server")
-      blockFetchFailureCallback(blockId, errorMsg)
+
+      val listener = outstandingRequests.get(blockId)
+      if (listener == null) {
+        // Ignore callback
+        logWarning(s"Got a response for block $blockId but it is not in our outstanding requests")
+      } else {
+        outstandingRequests.remove(blockId)
+        listener.onFetchFailure(blockId, errorMsg)
+      }
     } else {
       logTrace(s"Received block $blockId ($blockSize B) from $server")
-      blockFetchSuccessCallback(blockId, new ReferenceCountedBuffer(in))
+
+      val listener = outstandingRequests.get(blockId)
+      if (listener == null) {
+        // Ignore callback
+        logWarning(s"Got a response for block $blockId but it is not in our outstanding requests")
+      } else {
+        outstandingRequests.remove(blockId)
+        listener.onFetchSuccess(blockId, new ReferenceCountedBuffer(in))
+      }
     }
   }
 }
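
The explicit synchronized block in exceptionCaught is required because
java.util.Collections.synchronizedMap only guards individual method calls;
iterating over the map's views must be locked by the caller on the wrapper
itself. A standalone sketch of that pattern (keys and values are placeholders):

    import java.util.{Collections, HashMap => JHashMap}

    val registry = Collections.synchronizedMap(new JHashMap[String, String])
    registry.put("block_1", "listener_1")  // single calls are already synchronized

    // Iteration is NOT atomic by itself; hold the map's monitor around it.
    registry.synchronized {
      val iter = registry.entrySet().iterator()
      while (iter.hasNext) {
        val entry = iter.next()
        println(s"${entry.getKey} -> ${entry.getValue}")
      }
      registry.clear()
    }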

core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala

Lines changed: 33 additions & 30 deletions

@@ -18,7 +18,7 @@
 package org.apache.spark.storage

 import java.util.concurrent.LinkedBlockingQueue
-import org.apache.spark.network.netty.client.{LazyInitIterator, ReferenceCountedBuffer}
+import org.apache.spark.network.netty.client.{BlockClientListener, LazyInitIterator, ReferenceCountedBuffer}

 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashSet
@@ -285,37 +285,40 @@ object BlockFetcherIterator {

       client.fetchBlocks(
         blocks,
-        (blockId: String, refBuf: ReferenceCountedBuffer) => {
-          // Increment the reference count so the buffer won't be recycled.
-          // TODO: This could result in memory leaks when the task is stopped due to exception
-          // before the iterator is exhausted.
-          refBuf.retain()
-          val buf = refBuf.byteBuffer()
-          val blockSize = buf.remaining()
-          val bid = BlockId(blockId)
-
-          // TODO: remove code duplication between here and BlockManager.dataDeserialization.
-          results.put(new FetchResult(bid, sizeMap(bid), () => {
-            def createIterator: Iterator[Any] = {
-              val stream = blockManager.wrapForCompression(bid, refBuf.inputStream())
-              serializer.newInstance().deserializeStream(stream).asIterator
-            }
-            new LazyInitIterator(createIterator) {
-              // Release the buffer when we are done traversing it.
-              override def close(): Unit = refBuf.release()
-            }
-          }))
-
-          readMetrics.synchronized {
-            readMetrics.remoteBytesRead += blockSize
-            readMetrics.remoteBlocksFetched += 1
-          }
-          logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
-        },
-        (blockId: String, errorMsg: String) => {
-          logError(s"Could not get block(s) from $cmId with error: $errorMsg")
-          for ((blockId, size) <- req.blocks) {
-            results.put(new FetchResult(blockId, -1, null))
+        new BlockClientListener {
+          override def onFetchFailure(blockId: String, errorMsg: String): Unit = {
+            logError(s"Could not get block(s) from $cmId with error: $errorMsg")
+            for ((blockId, size) <- req.blocks) {
+              results.put(new FetchResult(blockId, -1, null))
+            }
+          }
+
+          override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = {
+            // Increment the reference count so the buffer won't be recycled.
+            // TODO: This could result in memory leaks when the task is stopped due to exception
+            // before the iterator is exhausted.
+            data.retain()
+            val buf = data.byteBuffer()
+            val blockSize = buf.remaining()
+            val bid = BlockId(blockId)
+
+            // TODO: remove code duplication between here and BlockManager.dataDeserialization.
+            results.put(new FetchResult(bid, sizeMap(bid), () => {
+              def createIterator: Iterator[Any] = {
+                val stream = blockManager.wrapForCompression(bid, data.inputStream())
+                serializer.newInstance().deserializeStream(stream).asIterator
+              }
+              new LazyInitIterator(createIterator) {
+                // Release the buffer when we are done traversing it.
+                override def close(): Unit = data.release()
+              }
+            }))
+
+            readMetrics.synchronized {
+              readMetrics.remoteBytesRead += blockSize
+              readMetrics.remoteBlocksFetched += 1
+            }
+            logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+          }
         }
       )
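
The retain()/release() pairing above follows Netty-style reference counting:
the buffer is retained before being handed to the lazily consumed iterator and
released when iteration closes. A minimal sketch of the underlying contract
using a raw Netty ByteBuf (ReferenceCountedBuffer wraps one):

    import io.netty.buffer.Unpooled

    val buf = Unpooled.copiedBuffer("block data".getBytes("UTF-8"))  // refCnt == 1
    buf.retain()   // refCnt == 2: safe to hand off to a second consumer
    buf.release()  // refCnt == 1: one consumer is done
    buf.release()  // refCnt == 0: buffer is deallocated / recycled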

core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala

Lines changed: 13 additions & 10 deletions

@@ -29,7 +29,7 @@ import io.netty.buffer.{ByteBufUtil, Unpooled}
 import org.scalatest.{BeforeAndAfterAll, FunSuite}

 import org.apache.spark.SparkConf
-import org.apache.spark.network.netty.client.{ReferenceCountedBuffer, BlockFetchingClientFactory}
+import org.apache.spark.network.netty.client.{BlockClientListener, ReferenceCountedBuffer, BlockFetchingClientFactory}
 import org.apache.spark.network.netty.server.BlockServer
 import org.apache.spark.storage.{FileSegment, BlockDataProvider}

@@ -99,15 +99,18 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll {

     client.fetchBlocks(
       blockIds,
-      (blockId, buf) => {
-        receivedBlockIds.add(blockId)
-        buf.retain()
-        receivedBuffers.add(buf)
-        sem.release()
-      },
-      (blockId, errorMsg) => {
-        errorBlockIds.add(blockId)
-        sem.release()
+      new BlockClientListener {
+        override def onFetchFailure(blockId: String, errorMsg: String): Unit = {
+          errorBlockIds.add(blockId)
+          sem.release()
+        }
+
+        override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = {
+          receivedBlockIds.add(blockId)
+          data.retain()
+          receivedBuffers.add(data)
+          sem.release()
+        }
       }
     )
     if (!sem.tryAcquire(blockIds.size, 30, TimeUnit.SECONDS)) {

core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala

Lines changed: 18 additions & 11 deletions

@@ -35,12 +35,17 @@ class BlockFetchingClientHandlerSuite extends FunSuite {
     var parsedBlockId: String = ""
     var parsedBlockData: String = ""
     val handler = new BlockFetchingClientHandler
-    handler.blockFetchSuccessCallback = (bid, refCntBuf) => {
-      parsedBlockId = bid
-      val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining)
-      refCntBuf.byteBuffer().get(bytes)
-      parsedBlockData = new String(bytes)
-    }
+    handler.addRequest(blockId,
+      new BlockClientListener {
+        override def onFetchFailure(blockId: String, errorMsg: String): Unit = ???
+        override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer): Unit = {
+          parsedBlockId = bid
+          val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining)
+          refCntBuf.byteBuffer().get(bytes)
+          parsedBlockData = new String(bytes)
+        }
+      }
+    )

     val channel = new EmbeddedChannel(handler)
     val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself
@@ -65,11 +70,13 @@ class BlockFetchingClientHandlerSuite extends FunSuite {
     var parsedBlockId: String = ""
     var parsedErrorMsg: String = ""
     val handler = new BlockFetchingClientHandler
-    handler.blockFetchFailureCallback = (bid, msg) => {
-      parsedBlockId = bid
-      parsedErrorMsg = msg
-    }
-
+    handler.addRequest(blockId, new BlockClientListener {
+      override def onFetchFailure(bid: String, msg: String) = {
+        parsedBlockId = bid
+        parsedErrorMsg = msg
+      }
+      override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer) = ???
+    })
     val channel = new EmbeddedChannel(handler)
     val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself
     buf.putInt(totalLength)
