Skip to content
/ hadoop Public
forked from apache/hadoop

Commit

Permalink
DX-66673: Backport HADOOP-15327. Upgrade MR ShuffleHandler to use Net…
Browse files Browse the repository at this point in the history
…ty4 apache#3259. Contributed by Szilard Nemeth.

(cherry picked from commit 5bb11ce)
Change-Id: If33ff1e8adda356dd5c6dadfb54fde9d31f84de7
  • Loading branch information
szilard-nemeth authored and szlta committed Jun 7, 2023
1 parent d01e231 commit 8d00592
Show file tree
Hide file tree
Showing 11 changed files with 1,602 additions and 571 deletions.
1 change: 1 addition & 0 deletions hadoop-client-modules/hadoop-client-runtime/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
<!-- Leave javax APIs that are stable -->
<!-- the jdk ships part of the javax.annotation namespace, so if we want to relocate this we'll have to care it out by class :( -->
<exclude>com.google.code.findbugs:jsr305</exclude>
<exclude>io.netty:*</exclude>
<exclude>io.dropwizard.metrics:metrics-core</exclude>
<exclude>org.eclipse.jetty:jetty-servlet</exclude>
<exclude>org.eclipse.jetty:jetty-security</exclude>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,15 @@

import org.apache.hadoop.thirdparty.com.google.common.annotations.VisibleForTesting;

class Fetcher<K,V> extends Thread {
@VisibleForTesting
public class Fetcher<K, V> extends Thread {

private static final Logger LOG = LoggerFactory.getLogger(Fetcher.class);

/** Number of ms before timing out a copy */
/** Number of ms before timing out a copy. */
private static final int DEFAULT_STALLED_COPY_TIMEOUT = 3 * 60 * 1000;

/** Basic/unit connection timeout (in milliseconds) */
/** Basic/unit connection timeout (in milliseconds). */
private final static int UNIT_CONNECT_TIMEOUT = 60 * 1000;

/* Default read timeout (in milliseconds) */
Expand All @@ -72,19 +73,21 @@ class Fetcher<K,V> extends Thread {
private static final String FETCH_RETRY_AFTER_HEADER = "Retry-After";

protected final Reporter reporter;
private enum ShuffleErrors{IO_ERROR, WRONG_LENGTH, BAD_ID, WRONG_MAP,
@VisibleForTesting
public enum ShuffleErrors{IO_ERROR, WRONG_LENGTH, BAD_ID, WRONG_MAP,
CONNECTION, WRONG_REDUCE}

private final static String SHUFFLE_ERR_GRP_NAME = "Shuffle Errors";

@VisibleForTesting
public final static String SHUFFLE_ERR_GRP_NAME = "Shuffle Errors";
private final JobConf jobConf;
private final Counters.Counter connectionErrs;
private final Counters.Counter ioErrs;
private final Counters.Counter wrongLengthErrs;
private final Counters.Counter badIdErrs;
private final Counters.Counter wrongMapErrs;
private final Counters.Counter wrongReduceErrs;
protected final MergeManager<K,V> merger;
protected final ShuffleSchedulerImpl<K,V> scheduler;
protected final MergeManager<K, V> merger;
protected final ShuffleSchedulerImpl<K, V> scheduler;
protected final ShuffleClientMetrics metrics;
protected final ExceptionReporter exceptionReporter;
protected final int id;
Expand All @@ -111,7 +114,7 @@ private enum ShuffleErrors{IO_ERROR, WRONG_LENGTH, BAD_ID, WRONG_MAP,
private static SSLFactory sslFactory;

public Fetcher(JobConf job, TaskAttemptID reduceId,
ShuffleSchedulerImpl<K,V> scheduler, MergeManager<K,V> merger,
ShuffleSchedulerImpl<K, V> scheduler, MergeManager<K, V> merger,
Reporter reporter, ShuffleClientMetrics metrics,
ExceptionReporter exceptionReporter, SecretKey shuffleKey) {
this(job, reduceId, scheduler, merger, reporter, metrics,
Expand All @@ -120,7 +123,7 @@ public Fetcher(JobConf job, TaskAttemptID reduceId,

@VisibleForTesting
Fetcher(JobConf job, TaskAttemptID reduceId,
ShuffleSchedulerImpl<K,V> scheduler, MergeManager<K,V> merger,
ShuffleSchedulerImpl<K, V> scheduler, MergeManager<K, V> merger,
Reporter reporter, ShuffleClientMetrics metrics,
ExceptionReporter exceptionReporter, SecretKey shuffleKey,
int id) {
Expand Down Expand Up @@ -315,9 +318,8 @@ protected void copyFromHost(MapHost host) throws IOException {
return;
}

if(LOG.isDebugEnabled()) {
LOG.debug("Fetcher " + id + " going to fetch from " + host + " for: "
+ maps);
if (LOG.isDebugEnabled()) {
LOG.debug("Fetcher " + id + " going to fetch from " + host + " for: " + maps);
}

// List of maps to be fetched yet
Expand Down Expand Up @@ -411,8 +413,8 @@ private void openConnectionWithRetry(URL url) throws IOException {
shouldWait = false;
} catch (IOException e) {
if (!fetchRetryEnabled) {
// throw exception directly if fetch's retry is not enabled
throw e;
// throw exception directly if fetch's retry is not enabled
throw e;
}
if ((Time.monotonicNow() - startTime) >= this.fetchRetryTimeout) {
LOG.warn("Failed to connect to host: " + url + "after "
Expand Down Expand Up @@ -489,7 +491,7 @@ private TaskAttemptID[] copyMapOutput(MapHost host,
DataInputStream input,
Set<TaskAttemptID> remaining,
boolean canRetry) throws IOException {
MapOutput<K,V> mapOutput = null;
MapOutput<K, V> mapOutput = null;
TaskAttemptID mapId = null;
long decompressedLength = -1;
long compressedLength = -1;
Expand Down Expand Up @@ -611,7 +613,7 @@ private void checkTimeoutOrRetry(MapHost host, IOException ioe)
// First time to retry.
long currentTime = Time.monotonicNow();
if (retryStartTime == 0) {
retryStartTime = currentTime;
retryStartTime = currentTime;
}

// Retry is not timeout, let's do retry with throwing an exception.
Expand All @@ -628,7 +630,7 @@ private void checkTimeoutOrRetry(MapHost host, IOException ioe)
}

/**
* Do some basic verification on the input received -- Being defensive
* Do some basic verification on the input received -- Being defensive.
* @param compressedLength
* @param decompressedLength
* @param forReduce
Expand Down Expand Up @@ -695,8 +697,7 @@ private URL getMapOutputURL(MapHost host, Collection<TaskAttemptID> maps
* only on the last failure. Instead of connecting with a timeout of
* X, we try connecting with a timeout of x < X but multiple times.
*/
private void connect(URLConnection connection, int connectionTimeout)
throws IOException {
private void connect(URLConnection connection, int connectionTimeout) throws IOException {
int unit = 0;
if (connectionTimeout < 0) {
throw new IOException("Invalid timeout "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableComparator;
import org.apache.hadoop.mapreduce.TaskCounter;
import org.apache.hadoop.mapreduce.task.reduce.Fetcher;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
Expand All @@ -37,6 +38,7 @@
import java.util.Formatter;
import java.util.Iterator;

import static org.apache.hadoop.mapreduce.task.reduce.Fetcher.SHUFFLE_ERR_GRP_NAME;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
Expand Down Expand Up @@ -87,6 +89,9 @@ public void testReduceFromPartialMem() throws Exception {
final long spill = c.findCounter(TaskCounter.SPILLED_RECORDS).getCounter();
assertTrue("Expected some records not spilled during reduce" + spill + ")",
spill < 2 * out); // spilled map records, some records at the reduce
long shuffleIoErrors =
c.getGroup(SHUFFLE_ERR_GRP_NAME).getCounter(Fetcher.ShuffleErrors.IO_ERROR.toString());
assertEquals(0, shuffleIoErrors);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import java.io.IOException;
import java.io.RandomAccessFile;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.handler.stream.ChunkedFile;
import org.apache.hadoop.thirdparty.com.google.common.annotations.VisibleForTesting;
import org.apache.hadoop.io.ReadaheadPool;
import org.apache.hadoop.io.ReadaheadPool.ReadaheadRequest;
Expand All @@ -31,8 +34,6 @@

import static org.apache.hadoop.io.nativeio.NativeIO.POSIX.POSIX_FADV_DONTNEED;

import org.jboss.netty.handler.stream.ChunkedFile;

public class FadvisedChunkedFile extends ChunkedFile {

private static final Logger LOG =
Expand Down Expand Up @@ -64,16 +65,16 @@ FileDescriptor getFd() {
}

@Override
public Object nextChunk() throws Exception {
public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception {
synchronized (closeLock) {
if (fd.valid()) {
if (manageOsCache && readaheadPool != null) {
readaheadRequest = readaheadPool
.readaheadStream(
identifier, fd, getCurrentOffset(), readaheadLength,
getEndOffset(), readaheadRequest);
identifier, fd, currentOffset(), readaheadLength,
endOffset(), readaheadRequest);
}
return super.nextChunk();
return super.readChunk(allocator);
} else {
return null;
}
Expand All @@ -88,12 +89,12 @@ public void close() throws Exception {
readaheadRequest = null;
}
if (fd.valid() &&
manageOsCache && getEndOffset() - getStartOffset() > 0) {
manageOsCache && endOffset() - startOffset() > 0) {
try {
NativeIO.POSIX.getCacheManipulator().posixFadviseIfPossible(
identifier,
fd,
getStartOffset(), getEndOffset() - getStartOffset(),
startOffset(), endOffset() - startOffset(),
POSIX_FADV_DONTNEED);
} catch (Throwable t) {
LOG.warn("Failed to manage OS cache for " + identifier +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.nio.channels.FileChannel;
import java.nio.channels.WritableByteChannel;

import io.netty.channel.DefaultFileRegion;
import org.apache.hadoop.io.ReadaheadPool;
import org.apache.hadoop.io.ReadaheadPool.ReadaheadRequest;
import org.apache.hadoop.io.nativeio.NativeIO;
Expand All @@ -33,8 +34,6 @@

import static org.apache.hadoop.io.nativeio.NativeIO.POSIX.POSIX_FADV_DONTNEED;

import org.jboss.netty.channel.DefaultFileRegion;

import org.apache.hadoop.thirdparty.com.google.common.annotations.VisibleForTesting;

public class FadvisedFileRegion extends DefaultFileRegion {
Expand Down Expand Up @@ -77,8 +76,8 @@ public long transferTo(WritableByteChannel target, long position)
throws IOException {
if (readaheadPool != null && readaheadLength > 0) {
readaheadRequest = readaheadPool.readaheadStream(identifier, fd,
getPosition() + position, readaheadLength,
getPosition() + getCount(), readaheadRequest);
position() + position, readaheadLength,
position() + count(), readaheadRequest);
}

if(this.shuffleTransferToAllowed) {
Expand Down Expand Up @@ -147,22 +146,22 @@ long customShuffleTransfer(WritableByteChannel target, long position)


@Override
public void releaseExternalResources() {
protected void deallocate() {
if (readaheadRequest != null) {
readaheadRequest.cancel();
}
super.releaseExternalResources();
super.deallocate();
}

/**
* Call when the transfer completes successfully so we can advise the OS that
* we don't need the region to be cached anymore.
*/
public void transferSuccessful() {
if (manageOsCache && getCount() > 0) {
if (manageOsCache && count() > 0) {
try {
NativeIO.POSIX.getCacheManipulator().posixFadviseIfPossible(identifier,
fd, getPosition(), getCount(), POSIX_FADV_DONTNEED);
fd, position(), count(), POSIX_FADV_DONTNEED);
} catch (Throwable t) {
LOG.warn("Failed to manage OS cache for " + identifier, t);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.hadoop.mapred;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseEncoder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;

class LoggingHttpResponseEncoder extends HttpResponseEncoder {
private static final Logger LOG = LoggerFactory.getLogger(LoggingHttpResponseEncoder.class);
private final boolean logStacktraceOfEncodingMethods;

LoggingHttpResponseEncoder(boolean logStacktraceOfEncodingMethods) {
this.logStacktraceOfEncodingMethods = logStacktraceOfEncodingMethods;
}

@Override
public boolean acceptOutboundMessage(Object msg) throws Exception {
printExecutingMethod();
LOG.info("OUTBOUND MESSAGE: " + msg);
return super.acceptOutboundMessage(msg);
}

@Override
protected void encodeInitialLine(ByteBuf buf, HttpResponse response) throws Exception {
LOG.debug("Executing method: {}, response: {}",
getExecutingMethodName(), response);
logStacktraceIfRequired();
super.encodeInitialLine(buf, response);
}

@Override
protected void encode(ChannelHandlerContext ctx, Object msg,
List<Object> out) throws Exception {
LOG.debug("Encoding to channel {}: {}", ctx.channel(), msg);
printExecutingMethod();
logStacktraceIfRequired();
super.encode(ctx, msg, out);
}

@Override
protected void encodeHeaders(HttpHeaders headers, ByteBuf buf) {
printExecutingMethod();
super.encodeHeaders(headers, buf);
}

@Override
public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise
promise) throws Exception {
LOG.debug("Writing to channel {}: {}", ctx.channel(), msg);
printExecutingMethod();
super.write(ctx, msg, promise);
}

private void logStacktraceIfRequired() {
if (logStacktraceOfEncodingMethods) {
LOG.debug("Stacktrace: ", new Throwable());
}
}

private void printExecutingMethod() {
String methodName = getExecutingMethodName(1);
LOG.debug("Executing method: {}", methodName);
}

private String getExecutingMethodName() {
return getExecutingMethodName(0);
}

private String getExecutingMethodName(int additionalSkipFrames) {
try {
StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace();
// Array items (indices):
// 0: java.lang.Thread.getStackTrace(...)
// 1: TestShuffleHandler$LoggingHttpResponseEncoder.getExecutingMethodName(...)
int skipFrames = 2 + additionalSkipFrames;
String methodName = stackTrace[skipFrames].getMethodName();
String className = this.getClass().getSimpleName();
return className + "#" + methodName;
} catch (Throwable t) {
LOG.error("Error while getting execution method name", t);
return "unknown";
}
}
}
Loading

0 comments on commit 8d00592

Please sign in to comment.