Skip to content

Commit

Permalink
HADOOP-17749. Remove lock contention in SelectorPool of SocketIOWithT…
Browse files Browse the repository at this point in the history
…imeout (#3080)

(cherry picked from commit a5db683)
  • Loading branch information
liangxs authored and ferhui committed Jul 7, 2021
1 parent 4912e4c commit b7108cf
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.SelectorProvider;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.atomic.AtomicBoolean;

import org.apache.hadoop.util.Time;
import org.slf4j.Logger;
Expand All @@ -48,8 +49,6 @@ abstract class SocketIOWithTimeout {
private long timeout;
private boolean closed = false;

private static SelectorPool selector = new SelectorPool();

/* A timeout value of 0 implies wait for ever.
* We should have a value of timeout that implies zero wait.. i.e.
* read or write returns immediately.
Expand Down Expand Up @@ -154,7 +153,7 @@ int doIO(ByteBuffer buf, int ops) throws IOException {
//now wait for socket to be ready.
int count = 0;
try {
count = selector.select(channel, ops, timeout);
count = SelectorPool.select(channel, ops, timeout);
} catch (IOException e) { //unexpected IOException.
closed = true;
throw e;
Expand Down Expand Up @@ -200,7 +199,7 @@ static void connect(SocketChannel channel,
// we might have to call finishConnect() more than once
// for some channels (with user level protocols)

int ret = selector.select((SelectableChannel)channel,
int ret = SelectorPool.select(channel,
SelectionKey.OP_CONNECT, timeoutLeft);

if (ret > 0 && channel.finishConnect()) {
Expand Down Expand Up @@ -242,7 +241,7 @@ static void connect(SocketChannel channel,
*/
void waitForIO(int ops) throws IOException {

if (selector.select(channel, ops, timeout) == 0) {
if (SelectorPool.select(channel, ops, timeout) == 0) {
throw new SocketTimeoutException(timeoutExceptionString(channel, timeout,
ops));
}
Expand Down Expand Up @@ -280,12 +279,17 @@ private static String timeoutExceptionString(SelectableChannel channel,
* This maintains a pool of selectors. These selectors are closed
* once they are idle (unused) for a few seconds.
*/
private static class SelectorPool {
private static final class SelectorPool {

private static class SelectorInfo {
Selector selector;
long lastActivityTime;
LinkedList<SelectorInfo> queue;
private static final class SelectorInfo {
private final SelectorProvider provider;
private final Selector selector;
private long lastActivityTime;

private SelectorInfo(SelectorProvider provider, Selector selector) {
this.provider = provider;
this.selector = selector;
}

void close() {
if (selector != null) {
Expand All @@ -298,16 +302,11 @@ void close() {
}
}

private static class ProviderInfo {
SelectorProvider provider;
LinkedList<SelectorInfo> queue; // lifo
ProviderInfo next;
}
private static ConcurrentHashMap<SelectorProvider, ConcurrentLinkedDeque
<SelectorInfo>> providerMap = new ConcurrentHashMap<>();

private static final long IDLE_TIMEOUT = 10 * 1000; // 10 seconds.

private ProviderInfo providerList = null;

/**
* Waits on the channel with the given timeout using one of the
* cached selectors. It also removes any cached selectors that are
Expand All @@ -319,7 +318,7 @@ private static class ProviderInfo {
* @return
* @throws IOException
*/
int select(SelectableChannel channel, int ops, long timeout)
static int select(SelectableChannel channel, int ops, long timeout)
throws IOException {

SelectorInfo info = get(channel);
Expand Down Expand Up @@ -385,35 +384,18 @@ int select(SelectableChannel channel, int ops, long timeout)
* @return
* @throws IOException
*/
private synchronized SelectorInfo get(SelectableChannel channel)
private static SelectorInfo get(SelectableChannel channel)
throws IOException {
SelectorInfo selInfo = null;

SelectorProvider provider = channel.provider();

// pick the list : rarely there is more than one provider in use.
ProviderInfo pList = providerList;
while (pList != null && pList.provider != provider) {
pList = pList.next;
}
if (pList == null) {
//LOG.info("Creating new ProviderInfo : " + provider.toString());
pList = new ProviderInfo();
pList.provider = provider;
pList.queue = new LinkedList<SelectorInfo>();
pList.next = providerList;
providerList = pList;
}

LinkedList<SelectorInfo> queue = pList.queue;

if (queue.isEmpty()) {
ConcurrentLinkedDeque<SelectorInfo> infoQ = providerMap.computeIfAbsent(
provider, k -> new ConcurrentLinkedDeque<>());

SelectorInfo selInfo = infoQ.pollLast(); // last in first out
if (selInfo == null) {
Selector selector = provider.openSelector();
selInfo = new SelectorInfo();
selInfo.selector = selector;
selInfo.queue = queue;
} else {
selInfo = queue.removeLast();
// selInfo will be put into infoQ after `#release()`
selInfo = new SelectorInfo(provider, selector);
}

trimIdleSelectors(Time.now());
Expand All @@ -426,34 +408,39 @@ private synchronized SelectorInfo get(SelectableChannel channel)
*
* @param info
*/
private synchronized void release(SelectorInfo info) {
private static void release(SelectorInfo info) {
long now = Time.now();
trimIdleSelectors(now);
info.lastActivityTime = now;
info.queue.addLast(info);
// SelectorInfos in queue are sorted by lastActivityTime
providerMap.get(info.provider).addLast(info);
}

private static AtomicBoolean trimming = new AtomicBoolean(false);

/**
* Closes selectors that are idle for IDLE_TIMEOUT (10 sec). It does not
* traverse the whole list, just over the one that have crossed
* the timeout.
*/
private void trimIdleSelectors(long now) {
private static void trimIdleSelectors(long now) {
if (!trimming.compareAndSet(false, true)) {
return;
}

long cutoff = now - IDLE_TIMEOUT;

for(ProviderInfo pList=providerList; pList != null; pList=pList.next) {
if (pList.queue.isEmpty()) {
continue;
}
for(Iterator<SelectorInfo> it = pList.queue.iterator(); it.hasNext();) {
SelectorInfo info = it.next();
if (info.lastActivityTime > cutoff) {
for (ConcurrentLinkedDeque<SelectorInfo> infoQ : providerMap.values()) {
SelectorInfo oldest;
while ((oldest = infoQ.peekFirst()) != null) {
if (oldest.lastActivityTime <= cutoff && infoQ.remove(oldest)) {
oldest.close();
} else {
break;
}
it.remove();
info.close();
}
}

trimming.set(false);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
import java.net.SocketTimeoutException;
import java.nio.channels.Pipe;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.hadoop.test.GenericTestUtils;
import org.apache.hadoop.test.MultithreadedTestUtil;
Expand Down Expand Up @@ -186,6 +191,46 @@ public void doWork() throws Exception {
}
}

@Test
public void testSocketIOWithTimeoutByMultiThread() throws Exception {
CountDownLatch latch = new CountDownLatch(1);
Runnable ioTask = () -> {
try {
Pipe pipe = Pipe.open();
try (Pipe.SourceChannel source = pipe.source();
InputStream in = new SocketInputStream(source, TIMEOUT);
Pipe.SinkChannel sink = pipe.sink();
OutputStream out = new SocketOutputStream(sink, TIMEOUT)) {

byte[] writeBytes = TEST_STRING.getBytes();
byte[] readBytes = new byte[writeBytes.length];
latch.await();

out.write(writeBytes);
doIO(null, out, TIMEOUT);

in.read(readBytes);
assertArrayEquals(writeBytes, readBytes);
doIO(in, null, TIMEOUT);
}
} catch (Exception e) {
fail(e.getMessage());
}
};

int threadCnt = 64;
ExecutorService threadPool = Executors.newFixedThreadPool(threadCnt);
for (int i = 0; i < threadCnt; ++i) {
threadPool.submit(ioTask);
}

Thread.sleep(1000);
latch.countDown();

threadPool.shutdown();
assertTrue(threadPool.awaitTermination(3, TimeUnit.SECONDS));
}

@Test
public void testSocketIOWithTimeoutInterrupted() throws Exception {
Pipe pipe = Pipe.open();
Expand Down Expand Up @@ -223,4 +268,38 @@ public void doWork() throws Exception {
ctx.stop();
}
}

@Test
public void testSocketIOWithTimeoutInterruptedByMultiThread()
throws Exception {
final int timeout = TIMEOUT * 10;
AtomicLong readCount = new AtomicLong();
AtomicLong exceptionCount = new AtomicLong();
Runnable ioTask = () -> {
try {
Pipe pipe = Pipe.open();
try (Pipe.SourceChannel source = pipe.source();
InputStream in = new SocketInputStream(source, timeout)) {
in.read();
readCount.incrementAndGet();
} catch (InterruptedIOException ste) {
exceptionCount.incrementAndGet();
}
} catch (Exception e) {
fail(e.getMessage());
}
};

int threadCnt = 64;
ExecutorService threadPool = Executors.newFixedThreadPool(threadCnt);
for (int i = 0; i < threadCnt; ++i) {
threadPool.submit(ioTask);
}
Thread.sleep(1000);
threadPool.shutdownNow();
threadPool.awaitTermination(1, TimeUnit.SECONDS);

assertEquals(0, readCount.get());
assertEquals(threadCnt, exceptionCount.get());
}
}

0 comments on commit b7108cf

Please sign in to comment.