Skip to content

[SPARK-21655][YARN] Support Kill CLI for Yarn mode #18897

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions common/network-common/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@
<artifactId>jackson-annotations</artifactId>
</dependency>

<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client</artifactId>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-yarn-common</artifactId>
</dependency>

<!-- Provided dependencies -->
<dependency>
<groupId>org.slf4j</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ public class TransportClient implements Closeable {
private final Channel channel;
private final TransportResponseHandler handler;
@Nullable private String clientId;
@Nullable private String clientUser;
private volatile boolean timedOut;

public TransportClient(Channel channel, TransportResponseHandler handler) {
Expand Down Expand Up @@ -114,6 +115,25 @@ public void setClientId(String id) {
this.clientId = id;
}

/**
* Returns the user name used by the client to authenticate itself when authentication is enabled.
*
* @return The client User Name, or null if authentication is disabled.
*/
public String getClientUser() {
return clientUser;
}

/**
* Sets the authenticated client's user name. This is meant to be used by the authentication layer.
*
* Trying to set a different client User Name after it's been set will result in an exception.
*/
public void setClientUser(String user) {
Preconditions.checkState(clientUser == null, "Client User Name has already been set.");
this.clientUser = user;
}

/**
* Requests a single chunk from the remote side, from the pre-negotiated streamId.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder sec
*/
@Override
public void doBootstrap(TransportClient client, Channel channel) {
SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, conf.saslEncryption());
SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, conf.saslEncryption(), conf);
try {
byte[] payload = saslClient.firstToken();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb
// First message in the handshake, setup the necessary state.
client.setClientId(saslMessage.appId);
saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder,
conf.saslServerAlwaysEncrypt());
conf.saslServerAlwaysEncrypt(), conf);
}

byte[] response;
Expand All @@ -114,6 +114,7 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb
// method returns. This assumes that the code ensures, through other means, that no outbound
// messages are being written to the channel while negotiation is still going on.
if (saslServer.isComplete()) {
client.setClientUser(saslServer.getUserName());
if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) {
logger.debug("SASL authentication successful for channel {}", client);
complete(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@

import static org.apache.spark.network.sasl.SparkSaslServer.*;

import org.apache.spark.network.util.TransportConf;

/**
* A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the
* initial state to the "authenticated" state. This client initializes the protocol via a
Expand All @@ -48,12 +50,25 @@ public class SparkSaslClient implements SaslEncryptionBackend {
private final String secretKeyId;
private final SecretKeyHolder secretKeyHolder;
private final String expectedQop;
private TransportConf conf;
private SaslClient saslClient;

public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder, boolean encrypt) {
public SparkSaslClient(
String secretKeyId,
SecretKeyHolder secretKeyHolder,
boolean alwaysEncrypt) {
this(secretKeyId,secretKeyHolder,alwaysEncrypt, null);
}

public SparkSaslClient(
String secretKeyId,
SecretKeyHolder secretKeyHolder,
boolean encrypt,
TransportConf conf) {
this.secretKeyId = secretKeyId;
this.secretKeyHolder = secretKeyHolder;
this.expectedQop = encrypt ? QOP_AUTH_CONF : QOP_AUTH;
this.conf = conf;

Map<String, String> saslProps = ImmutableMap.<String, String>builder()
.put(Sasl.QOP, expectedQop)
Expand Down Expand Up @@ -131,11 +146,23 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback
if (callback instanceof NameCallback) {
logger.trace("SASL client callback: setting username");
NameCallback nc = (NameCallback) callback;
nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
if (conf != null && conf.isConnectionUsingTokens()) {
// Token Identifier is already encoded
nc.setName(secretKeyHolder.getSaslUser(secretKeyId));
} else {
nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
}

} else if (callback instanceof PasswordCallback) {
logger.trace("SASL client callback: setting password");
PasswordCallback pc = (PasswordCallback) callback;
pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
if (conf != null && conf.isConnectionUsingTokens()) {
// Token Identifier is already encoded
pc.setPassword(secretKeyHolder.getSecretKey(secretKeyId).toCharArray());
} else {
pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));

}
} else if (callback instanceof RealmCallback) {
logger.trace("SASL client callback: setting realm");
RealmCallback rc = (RealmCallback) callback;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;
Expand All @@ -40,6 +42,12 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.hadoop.security.token.SecretManager.InvalidToken;
import org.apache.hadoop.yarn.security.client.ClientToAMTokenSecretManager;
import org.apache.hadoop.yarn.security.client.ClientToAMTokenIdentifier;

import org.apache.spark.network.util.TransportConf;

/**
* A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the
* initial state to the "authenticated" state. (It is not a server in the sense of accepting
Expand Down Expand Up @@ -73,14 +81,25 @@ public class SparkSaslServer implements SaslEncryptionBackend {
/** Identifier for a certain secret key within the secretKeyHolder. */
private final String secretKeyId;
private final SecretKeyHolder secretKeyHolder;
private TransportConf conf;
private String clientUser;
private SaslServer saslServer;

public SparkSaslServer(
String secretKeyId,
SecretKeyHolder secretKeyHolder,
boolean alwaysEncrypt) {
this(secretKeyId, secretKeyHolder, alwaysEncrypt, null);
}

public SparkSaslServer(
String secretKeyId,
SecretKeyHolder secretKeyHolder,
boolean alwaysEncrypt,
TransportConf conf) {
this.secretKeyId = secretKeyId;
this.secretKeyHolder = secretKeyHolder;
this.conf = conf;

// Sasl.QOP is a comma-separated list of supported values. The value that allows encryption
// is listed first since it's preferred over the non-encrypted one (if the client also
Expand All @@ -98,6 +117,13 @@ public SparkSaslServer(
}
}

/**
* Returns the user name of the client.
*/
public String getUserName() {
return clientUser;
}

/**
* Determines whether the authentication exchange has completed successfully.
*/
Expand Down Expand Up @@ -156,15 +182,16 @@ public byte[] unwrap(byte[] data, int offset, int len) throws SaslException {
private class DigestCallbackHandler implements CallbackHandler {
@Override
public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
NameCallback nc = null;
PasswordCallback pc = null;
for (Callback callback : callbacks) {
if (callback instanceof NameCallback) {
logger.trace("SASL server callback: setting username");
NameCallback nc = (NameCallback) callback;
nc = (NameCallback) callback;
nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId)));
} else if (callback instanceof PasswordCallback) {
logger.trace("SASL server callback: setting password");
PasswordCallback pc = (PasswordCallback) callback;
pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
pc = (PasswordCallback) callback;
} else if (callback instanceof RealmCallback) {
logger.trace("SASL server callback: setting realm");
RealmCallback rc = (RealmCallback) callback;
Expand All @@ -182,10 +209,45 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback
throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback");
}
}
if (pc != null) {
if (conf != null && conf.isConnectionUsingTokens()) {
ClientToAMTokenSecretManager secretManager = new ClientToAMTokenSecretManager(null,
decodeMasterKey(secretKeyHolder.getSecretKey(secretKeyId)));
ClientToAMTokenIdentifier identifier = getIdentifier(nc.getDefaultName());
clientUser = identifier.getUser().getShortUserName();
pc.setPassword(getClientToAMSecretKey(identifier, secretManager));
} else {
clientUser = nc.getDefaultName();
pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId)));
}
}
}
}

/* Encode a byte[] identifier as a Base64-encoded string. */
/** Creates an ClientToAMTokenIdentifier from the encoded Base-64 String */
private static ClientToAMTokenIdentifier getIdentifier(String id) throws InvalidToken {
byte[] tokenId = byteBufToByte(Base64.decode(
Unpooled.wrappedBuffer(id.getBytes(StandardCharsets.UTF_8))));

ClientToAMTokenIdentifier tokenIdentifier = new ClientToAMTokenIdentifier();
try {
tokenIdentifier.readFields(new DataInputStream(new ByteArrayInputStream(tokenId)));
} catch (IOException e) {
throw (InvalidToken) new InvalidToken(
"Can't de-serialize tokenIdentifier").initCause(e);
}
return tokenIdentifier;
}

/** Returns an Base64-encoded secretKey created from the Identifier and the secretmanager */
private char[] getClientToAMSecretKey(ClientToAMTokenIdentifier tokenid,
ClientToAMTokenSecretManager secretManager) throws InvalidToken {
byte[] password = secretManager.retrievePassword(tokenid);
return Base64.encode(Unpooled.wrappedBuffer(password)).toString(StandardCharsets.UTF_8)
.toCharArray();
}

/** Encode a String identifier as a Base64-encoded string. */
public static String encodeIdentifier(String identifier) {
Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled");
return getBase64EncodedString(identifier);
Expand All @@ -197,6 +259,25 @@ public static char[] encodePassword(String password) {
return getBase64EncodedString(password).toCharArray();
}

/** Decode a base64-encoded indentifier as a String. */
public static String decodeIdentifier(String identifier) {
Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled");
return Base64.decode(Unpooled.wrappedBuffer(identifier.getBytes(StandardCharsets.UTF_8)))
.toString(StandardCharsets.UTF_8);
}

/** Decode a base64-encoded MasterKey as a byte[] array. */
public static byte[] decodeMasterKey(String masterKey) {
ByteBuf masterKeyByteBuf = Base64.decode(Unpooled.wrappedBuffer(masterKey.getBytes(StandardCharsets.UTF_8)));
return byteBufToByte(masterKeyByteBuf);
}

/** Convert an ByteBuf to a byte[] array. */
private static byte[] byteBufToByte(ByteBuf byteBuf) {
byte[] byteArray = new byte[byteBuf.readableBytes()];
byteBuf.readBytes(byteArray);
return byteArray;
}
/** Return a Base64-encoded string. */
private static String getBase64EncodedString(String str) {
ByteBuf byteBuf = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ public int numConnectionsPerPeer() {
/** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */
public int clientThreads() { return conf.getInt(SPARK_NETWORK_IO_CLIENTTHREADS_KEY, 0); }

/** If true, the current RPC connection is a Client to AM connection */
public boolean isConnectionUsingTokens() { return conf.getBoolean("spark.rpc.connectionUsingTokens", false); }

/**
* Receive buffer size (SO_RCVBUF).
* Note: the optimal size for receive buffer and send buffer should be
Expand Down
23 changes: 18 additions & 5 deletions core/src/main/scala/org/apache/spark/SecurityManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,11 @@ private[spark] class SecurityManager(
setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", ""))
setModifyAcls(defaultAclUsers, sparkConf.get("spark.modify.acls", ""))

setViewAclsGroups(sparkConf.get("spark.ui.view.acls.groups", ""));
setModifyAclsGroups(sparkConf.get("spark.modify.acls.groups", ""));
setViewAclsGroups(sparkConf.get("spark.ui.view.acls.groups", ""))
setModifyAclsGroups(sparkConf.get("spark.modify.acls.groups", ""))

private val secretKey = generateSecretKey()
private var identifier = "sparkSaslUser"
private var secretKey = generateSecretKey()
logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") +
"; ui acls " + (if (aclsOn) "enabled" else "disabled") +
"; users with view permissions: " + viewAcls.toString() +
Expand Down Expand Up @@ -533,11 +534,23 @@ private[spark] class SecurityManager(

/**
* Gets the user used for authenticating SASL connections.
* For now use a single hardcoded user.
* @return the SASL user as a String
*/
def getSaslUser(): String = "sparkSaslUser"
def getSaslUser(): String = identifier

/**
* This can be a user name or unique identifier
*/
def setSaslUser(ident: String) {
identifier = ident
}

/**
* set the secret key
*/
def setSecretKey(secret: String) {
secretKey = secret
}
/**
* Gets the secret key.
* @return the secret key as a String if authentication is enabled, otherwise returns null
Expand Down
Loading