
Commit 6ce3f3c

Added some basic test cases.
1 parent 47f7ce0 commit 6ce3f3c

File tree: 12 files changed, +255 -152 lines

core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala

Lines changed: 12 additions & 6 deletions
@@ -22,22 +22,26 @@ import org.apache.spark.SparkConf
 /**
  * A central location that tracks all the settings we exposed to users.
  */
+private[spark]
 class NettyConfig(conf: SparkConf) {

+  /** Port the server listens on. Default to a random port. */
+  private[netty] val serverPort = conf.getInt("spark.shuffle.io.port", 0)
+
   /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */
-  val ioMode = conf.get("spark.shuffle.io.mode", "auto").toLowerCase
+  private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "auto").toLowerCase

   /** Connect timeout in secs. Default 60 secs. */
-  val connectTimeoutMs = conf.getInt("spark.shuffle.io.connectionTimeout", 60) * 1000
+  private[netty] val connectTimeoutMs = conf.getInt("spark.shuffle.io.connectionTimeout", 60) * 1000

   /**
    * Percentage of the desired amount of time spent for I/O in the child event loops.
    * Only applicable in nio and epoll.
    */
-  val ioRatio = conf.getInt("spark.shuffle.io.netty.ioRatio", 80)
+  private[netty] val ioRatio = conf.getInt("spark.shuffle.io.netty.ioRatio", 80)

   /** Requested maximum length of the queue of incoming connections. */
-  val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt)
+  private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt)

   /**
    * Receive buffer size (SO_RCVBUF).
@@ -46,8 +50,10 @@ class NettyConfig(conf: SparkConf) {
    * Assuming latency = 1ms, network_bandwidth = 10Gbps
    * buffer size should be ~ 1.25MB
    */
-  val receiveBuf: Option[Int] = conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt)
+  private[netty] val receiveBuf: Option[Int] =
+    conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt)

   /** Send buffer size (SO_SNDBUF). */
-  val sendBuf: Option[Int] = conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt)
+  private[netty] val sendBuf: Option[Int] =
+    conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt)
 }
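These settings all come from the standard SparkConf. As a rough illustration (not code from this commit), the new spark.shuffle.io.port key pins the shuffle server to a fixed port instead of the random default; the values below are examples, and the fields are private[netty], so they are only readable from inside that package:

import org.apache.spark.SparkConf
import org.apache.spark.network.netty.NettyConfig

// Sketch only: example values, not defaults from this commit.
val sparkConf = new SparkConf()
  .set("spark.shuffle.io.port", "31337")            // fixed server port; 0 (the default) picks a random port
  .set("spark.shuffle.io.mode", "epoll")            // nio, oio, epoll, or auto
  .set("spark.shuffle.io.connectionTimeout", "120") // seconds; NettyConfig converts to millis

val nettyConf = new NettyConfig(sparkConf)
// Inside org.apache.spark.network.netty: nettyConf.serverPort == 31337, nettyConf.connectTimeoutMs == 120000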

core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala

Lines changed: 1 addition & 3 deletions
@@ -129,7 +129,5 @@ class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String,
     cf.channel().closeFuture().sync()
   }

-  def close(): Unit = {
-    // TODO: Should we ever close the client? Probably ...
-  }
+  def close(): Unit = cf.channel().close()
 }
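A minimal caller sketch; apart from createClient (visible in the BlockFetcherIterator change further down) and the new close(), the names below are placeholders and the fetch API is not part of this diff:

// Hypothetical caller: factory, host and port come from surrounding fetch code.
val client = factory.createClient(host, port)
try {
  // ... issue block fetch requests over the channel (API not shown in this commit) ...
} finally {
  // With this change, close() actually shuts down the underlying Netty channel
  // instead of being a no-op TODO.
  client.close()
}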

core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala

Lines changed: 1 addition & 6 deletions
@@ -17,8 +17,6 @@

 package org.apache.spark.network.netty.client

-import java.net.InetSocketAddress
-
 import io.netty.buffer.ByteBuf
 import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}

@@ -43,10 +41,7 @@ class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] wi
     val blockId = new String(blockIdBytes)
     val blockLen = math.abs(totalLen) - blockIdLen - 4

-    def server = {
-      val remoteAddr = ctx.channel.remoteAddress.asInstanceOf[InetSocketAddress]
-      remoteAddr.getHostName + ":" + remoteAddr.getPort
-    }
+    def server = ctx.channel.remoteAddress.toString

     // totalLen is negative when it is an error message.
     if (totalLen < 0) {
core/src/main/scala/org/apache/spark/network/netty/client/ClientTester.scala

Lines changed: 0 additions & 45 deletions
This file was deleted.

core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala

Lines changed: 16 additions & 11 deletions
@@ -21,26 +21,27 @@ import java.net.InetSocketAddress

 import io.netty.bootstrap.ServerBootstrap
 import io.netty.buffer.PooledByteBufAllocator
-import io.netty.channel.socket.SocketChannel
-import io.netty.channel.{ChannelInitializer, ChannelOption, ChannelFuture}
-import io.netty.channel.epoll.{EpollServerSocketChannel, EpollEventLoopGroup}
+import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption}
+import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel}
 import io.netty.channel.nio.NioEventLoopGroup
 import io.netty.channel.oio.OioEventLoopGroup
+import io.netty.channel.socket.SocketChannel
 import io.netty.channel.socket.nio.NioServerSocketChannel
 import io.netty.channel.socket.oio.OioServerSocketChannel
 import io.netty.handler.codec.LineBasedFrameDecoder
 import io.netty.handler.codec.string.StringDecoder
 import io.netty.util.CharsetUtil

-import org.apache.spark.{SparkConf, Logging}
-import org.apache.spark.network.netty.{PathResolver, NettyConfig}
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.network.netty.NettyConfig
+import org.apache.spark.storage.BlockDataProvider
 import org.apache.spark.util.Utils

+
 // TODO: Remove dependency on BlockId. This layer should not be coupled with storage.

 // TODO: PathResolver is not general enough. It only works for on-disk blocks.

-// TODO: Allow user-configured port

 /**
  * Server for serving Spark data blocks.
@@ -58,15 +59,18 @@ import org.apache.spark.util.Utils
  *
  */
 private[spark]
-class BlockServer(conf: NettyConfig, pResolver: PathResolver) extends Logging {
+class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Logging {

-  def this(sparkConf: SparkConf, pResolver: PathResolver) = {
-    this(new NettyConfig(sparkConf), pResolver)
+  def this(sparkConf: SparkConf, dataProvider: BlockDataProvider) = {
+    this(new NettyConfig(sparkConf), dataProvider)
   }

   def port: Int = _port

-  private var _port: Int = 0
+  def hostName: String = _hostName
+
+  private var _port: Int = conf.serverPort
+  private var _hostName: String = ""
   private var bootstrap: ServerBootstrap = _
   private var channelFuture: ChannelFuture = _

@@ -134,7 +138,7 @@ class BlockServer(conf: NettyConfig, pResolver: PathResolver) extends Logging {
         .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024
         .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8))
         .addLast("blockHeaderEncoder", new BlockHeaderEncoder)
-        .addLast("handler", new BlockServerHandler(pResolver))
+        .addLast("handler", new BlockServerHandler(dataProvider))
     }
   })

@@ -143,6 +147,7 @@ class BlockServer(conf: NettyConfig, pResolver: PathResolver) extends Logging {

     val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress]
     _port = addr.getPort
+    _hostName = addr.getHostName
   }

   /** Shutdown the server. */
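A minimal start-up sketch in the spirit of the deleted ServerTester, assuming some BlockDataProvider implementation named blockDataProvider is in scope (a placeholder; the ServerTester code suggests the constructor binds the socket immediately):

import org.apache.spark.SparkConf
import org.apache.spark.network.netty.server.BlockServer

// blockDataProvider: BlockDataProvider is assumed to exist; see the trait below.
val conf = new SparkConf().set("spark.shuffle.io.port", "0") // 0 = pick a random port
val server = new BlockServer(conf, blockDataProvider)

// The bound address is exposed through the accessors added in this commit.
println(s"Block server listening on ${server.hostName}:${server.port}")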

core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala

Lines changed: 67 additions & 39 deletions
@@ -18,14 +18,14 @@
 package org.apache.spark.network.netty.server

 import java.io.FileInputStream
-import java.net.InetSocketAddress
+import java.nio.ByteBuffer
 import java.nio.channels.FileChannel

+import io.netty.buffer.Unpooled
 import io.netty.channel._

 import org.apache.spark.Logging
-import org.apache.spark.network.netty.PathResolver
-import org.apache.spark.storage.BlockId
+import org.apache.spark.storage.{FileSegment, BlockDataProvider}


 /**
@@ -35,15 +35,11 @@ import org.apache.spark.storage.BlockId
  * so channelRead0 is called once per line (i.e. per block id).
  */
 private[server]
-class BlockServerHandler(p: PathResolver)
+class BlockServerHandler(dataProvider: BlockDataProvider)
   extends SimpleChannelInboundHandler[String] with Logging {

   override def channelRead0(ctx: ChannelHandlerContext, blockId: String): Unit = {
-    // client in the form of hostname:port
-    val client = {
-      val remoteAddr = ctx.channel.remoteAddress.asInstanceOf[InetSocketAddress]
-      remoteAddr.getHostName + ":" + remoteAddr.getPort
-    }
+    def client = ctx.channel.remoteAddress.toString

     // A helper function to send error message back to the client.
     def respondWithError(error: String): Unit = {
@@ -60,48 +56,80 @@ class BlockServerHandler(p: PathResolver)
       )
     }

-    logTrace(s"Received request from $client to fetch block $blockId")
+    def writeFileSegment(segment: FileSegment): Unit = {
+      // Send error message back if the block is too large. Even though we are capable of sending
+      // large (2G+) blocks, the receiving end cannot handle it so let's fail fast.
+      // Once we fixed the receiving end to be able to process large blocks, this should be removed.
+      // Also make sure we update BlockHeaderEncoder to support length > 2G.

-    var fileChannel: FileChannel = null
-    var offset: Long = 0
-    var blockSize: Long = 0
+      // See [[BlockHeaderEncoder]] for the way length is encoded.
+      if (segment.length + blockId.length + 4 > Int.MaxValue) {
+        respondWithError(s"Block $blockId size ($segment.length) greater than 2G")
+        return
+      }

-    // First make sure we can find the block. If not, send error back to the user.
-    try {
-      val segment = p.getBlockLocation(BlockId(blockId))
-      fileChannel = new FileInputStream(segment.file).getChannel
-      offset = segment.offset
-      blockSize = segment.length
-    } catch {
-      case e: Exception =>
-        logError(s"Error opening block $blockId for request from $client", e)
-        blockSize = -1
-        respondWithError(e.getMessage)
-    }
+      var fileChannel: FileChannel = null
+      try {
+        fileChannel = new FileInputStream(segment.file).getChannel
+      } catch {
+        case e: Exception =>
+          logError(
+            s"Error opening channel for $blockId in ${segment.file} for request from $client", e)
+          respondWithError(e.getMessage)
+      }
+
+      // Found the block. Send it back.
+      if (fileChannel != null) {
+        // Write the header and block data. In the case of failures, the listener on the block data
+        // write should close the connection.
+        ctx.write(new BlockHeader(segment.length.toInt, blockId))

-    // Send error message back if the block is too large. Even though we are capable of sending
-    // large (2G+) blocks, the receiving end cannot handle it so let's fail fast.
-    // Once we fixed the receiving end to be able to process large blocks, this should be removed.
-    // Also make sure we update BlockHeaderEncoder to support length > 2G.
-    if (blockSize > Int.MaxValue) {
-      respondWithError(s"Block $blockId size ($blockSize) greater than 2G")
+        val region = new DefaultFileRegion(fileChannel, segment.offset, segment.length)
+        ctx.writeAndFlush(region).addListener(new ChannelFutureListener {
+          override def operationComplete(future: ChannelFuture) {
+            if (future.isSuccess) {
+              logTrace(s"Sent block $blockId (${segment.length} B) back to $client")
+            } else {
+              logError(s"Error sending block $blockId to $client; closing connection", future.cause)
+              ctx.close()
+            }
+          }
+        })
+      }
     }

-    // Found the block. Send it back.
-    if (fileChannel != null && blockSize >= 0) {
-      val listener = new ChannelFutureListener {
+    def writeByteBuffer(buf: ByteBuffer): Unit = {
+      ctx.write(new BlockHeader(buf.remaining, blockId))
+      ctx.writeAndFlush(Unpooled.wrappedBuffer(buf)).addListener(new ChannelFutureListener {
         override def operationComplete(future: ChannelFuture) {
           if (future.isSuccess) {
-            logTrace(s"Sent block $blockId ($blockSize B) back to $client")
+            logTrace(s"Sent block $blockId (${buf.remaining} B) back to $client")
           } else {
             logError(s"Error sending block $blockId to $client; closing connection", future.cause)
             ctx.close()
           }
         }
-      }
-      val region = new DefaultFileRegion(fileChannel, offset, blockSize)
-      ctx.writeAndFlush(new BlockHeader(blockSize.toInt, blockId)).addListener(listener)
-      ctx.writeAndFlush(region).addListener(listener)
+      })
+    }
+
+    logTrace(s"Received request from $client to fetch block $blockId")
+
+    var blockData: Either[FileSegment, ByteBuffer] = null
+
+    // First make sure we can find the block. If not, send error back to the user.
+    try {
+      blockData = dataProvider.getBlockData(blockId)
+    } catch {
+      case e: Exception =>
+        logError(s"Error opening block $blockId for request from $client", e)
+        respondWithError(e.getMessage)
+        return
+    }
+
+    blockData match {
+      case Left(segment) => writeFileSegment(segment)
+      case Right(buf) => writeByteBuffer(buf)
     }
+
   } // end of channelRead0
 }

core/src/main/scala/org/apache/spark/network/netty/server/ServerTester.scala renamed to core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala

Lines changed: 12 additions & 15 deletions
@@ -15,21 +15,18 @@
  * limitations under the License.
  */

-package org.apache.spark.network.netty.server
+package org.apache.spark.storage

-import org.apache.spark.SparkConf
-import org.apache.spark.network.netty.PathResolver
-import org.apache.spark.storage.{TestBlockId, FileSegment, BlockId}
+import java.nio.ByteBuffer

-/** A simple main function for testing the server. */
-object ServerTester {
-  def main(args: Array[String]): Unit = {
-    new BlockServer(new SparkConf, new PathResolver {
-      override def getBlockLocation(blockId: BlockId): FileSegment = {
-        val file = new java.io.File(blockId.asInstanceOf[TestBlockId].id)
-        new FileSegment(file, 0, file.length())
-      }
-    })
-    Thread.sleep(1000000)
-  }
+
+/**
+ * An interface for providing data for blocks.
+ *
+ * getBlockData returns either a FileSegment (for zero-copy send), or a ByteBuffer.
+ *
+ * Aside from unit tests, [[BlockManager]] is the main class that implements this.
+ */
+private[spark] trait BlockDataProvider {
+  def getBlockData(blockId: String): Either[FileSegment, ByteBuffer]
 }
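For illustration, a minimal on-disk implementation along the lines of the deleted ServerTester; the directory-per-block mapping and the class name are hypothetical, and in real use [[BlockManager]] is the implementation:

import java.io.File
import java.nio.ByteBuffer

import org.apache.spark.storage.{BlockDataProvider, FileSegment}

// Sketch only: serves each block id as a whole file under a base directory.
class DirectoryBlockDataProvider(baseDir: File) extends BlockDataProvider {
  override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
    val file = new File(baseDir, blockId)
    if (file.exists()) {
      // A FileSegment lets BlockServerHandler send the data with zero-copy (DefaultFileRegion).
      Left(new FileSegment(file, 0, file.length()))
    } else {
      // BlockServerHandler catches exceptions and replies to the client with an error message.
      throw new IllegalArgumentException(s"Block $blockId not found under $baseDir")
    }
  }
}

Small in-memory blocks could instead be returned as Right(ByteBuffer), which the handler wraps with Unpooled.wrappedBuffer before flushing.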

core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala

Lines changed: 1 addition & 1 deletion
@@ -277,7 +277,7 @@ object BlockFetcherIterator {

       bytesInFlight += req.size
       val sizeMap = req.blocks.toMap // so we can look up the size of each blockID
-      val client = blockManager.diskBlockManager.nettyBlockClientFactory.createClient(
+      val client = blockManager.nettyBlockClientFactory.createClient(
         cmId.host, req.address.nettyPort)
       val blocks = req.blocks.map(_._1.toString)