Skip to content

Commit bade4d4

Browse files
GschiavonMarcos P
authored andcommitted
fix token flow (apache#44)
1 parent 8200918 commit bade4d4

File tree

1 file changed

+74
-51
lines changed

1 file changed

+74
-51
lines changed

core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala

Lines changed: 74 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import java.text.ParseException
2626
import scala.annotation.tailrec
2727
import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
2828
import scala.util.Properties
29-
3029
import org.apache.commons.lang3.StringUtils
3130
import org.apache.hadoop.fs.Path
3231
import org.apache.hadoop.security.UserGroupInformation
@@ -41,11 +40,11 @@ import org.apache.ivy.core.settings.IvySettings
4140
import org.apache.ivy.plugins.matcher.GlobPatternMatcher
4241
import org.apache.ivy.plugins.repository.file.FileRepository
4342
import org.apache.ivy.plugins.resolver.{ChainResolver, FileSystemResolver, IBiblioResolver}
44-
4543
import org.apache.spark.{SPARK_REVISION, SPARK_VERSION, SparkException, SparkUserAppException}
4644
import org.apache.spark.{SPARK_BRANCH, SPARK_BUILD_DATE, SPARK_BUILD_USER, SPARK_REPO_URL}
4745
import org.apache.spark.api.r.RUtils
4846
import org.apache.spark.deploy.rest._
47+
import org.apache.spark.internal.Logging
4948
import org.apache.spark.launcher.SparkLauncher
5049
import org.apache.spark.scheduler.{KerberosUser, KerberosUtil}
5150
import org.apache.spark.security.{ConfigSecurity, VaultHelper}
@@ -665,64 +664,88 @@ object SparkSubmit {
665664
}
666665

667666

668-
val mesosRoleEnv = (sys.env.get("VAULT_ROLE_ID"),
669-
sys.env.get("VAULT_SECRET_ID"))
670-
671-
val sparkRoleOpts = (args.sparkProperties.get("spark.secret.roleID"),
672-
args.sparkProperties.get("spark.secret.secretID"))
673-
674-
val tempToken = args.sparkProperties.get("spark.secret.vault.tempToken")
667+
val tempToken = (args.sparkProperties.get("spark.secret.vault.tempToken"),
668+
sys.env.get("VAULT_TEMP_TOKEN")) match {
669+
case (Some(property), env) => Option(property)
670+
case (property, Some(env)) => Option(env)
671+
case _ => None
672+
}
675673

676-
val sysEnvToken = sys.env.get("VAULT_TEMP_TOKEN")
674+
val roleSecret = (args.sparkProperties.get("spark.secret.roleID"),
675+
args.sparkProperties.get("spark.secret.secretID"),
676+
sys.env.get("VAULT_ROLE_ID"),
677+
sys.env.get("VAULT_SECRET_ID")) match {
678+
case (Some(roleProperty), Some(secretProperty), roleEnv, secretEnv) =>
679+
Option(roleProperty, secretProperty)
680+
case (roleProperty, secretProperty, Some(roleEnv), Some(secretEnv)) =>
681+
Option(roleEnv, secretEnv)
682+
case _ => None
683+
}
677684

678685
val vaultProtocol = args.sparkProperties.get("spark.secret.vault.protocol")
679-
val vaultHosts = args.sparkProperties.get("spark.secret.vault.hosts")
686+
val vaultHost = args.sparkProperties.get("spark.secret.vault.hosts")
680687
val vaultPort = args.sparkProperties.get("spark.secret.vault.port")
681688

682-
val (pincipal, keytab) =
683-
(mesosRoleEnv, sparkRoleOpts, tempToken, sysEnvToken,
684-
vaultProtocol, vaultHosts, vaultPort) match {
685-
686-
case ((roleIdEnv, secretIdEnv), (roleIdProp, secretIdProp), _, _,
687-
Some(protocol), Some(hosts), Some(port))
688-
if ((roleIdEnv.isDefined || roleIdProp.isDefined) &&
689-
(secretIdEnv.isDefined || secretIdProp.isDefined)) =>
690-
val vaultUrl = s"$protocol://${hosts.split(",")
691-
.map(host => s"$host:$port").mkString(",")}"
692-
693-
val roleId = roleIdEnv.getOrElse(roleIdProp.get)
694-
val secretId = secretIdEnv.getOrElse(secretIdProp.get)
695-
val vaultToken = VaultHelper.getTokenFromAppRole (vaultUrl, roleId, secretId)
696-
val environment = ConfigSecurity.prepareEnvironment(
697-
Option(vaultToken), Option (vaultUrl) )
698-
val principal = environment.get ("principal").getOrElse (args.principal)
699-
val keytab = environment.get ("keytabPath").getOrElse (args.keytab)
700-
701-
environment.foreach {
702-
case (key, value) => sysProps.put (key, value)
703-
}
704-
(principal, keytab)
689+
val vaultUrlParams = (vaultProtocol, vaultHost, vaultPort)
690+
val vaultUrl = buildVaultUrl(vaultUrlParams)
691+
lazy val vaultToken = getToken(tempToken, roleSecret, vaultUrl)
705692

706-
case (_, _, tempTokenProp, tempTokenEnv, Some(protocol), Some(hosts), Some(port))
707-
if (tempTokenProp.isDefined || tempTokenEnv.isDefined) =>
708-
val vaultUrl = s"$protocol://${hosts.split(",")
709-
.map(host => s"$host:${port}").mkString(",")}"
710-
val tempToken = tempTokenProp.getOrElse(tempTokenEnv.get)
711-
val vaultToken = VaultHelper.getRealToken (vaultUrl, tempToken)
712-
val environment = ConfigSecurity.prepareEnvironment(
713-
Option (vaultToken), Option (vaultUrl))
714-
val principal = environment.get ("principal").getOrElse (args.principal)
715-
val keytab = environment.get ("keytabPath").getOrElse (args.keytab)
716-
717-
environment.foreach {
718-
case (key, value) => sysProps.put (key, value)
719-
}
720-
(principal, keytab)
693+
val (principal, keytab) =
694+
if (vaultUrl.nonEmpty && vaultToken.isDefined) {
695+
val environment = ConfigSecurity.prepareEnvironment(
696+
Option (vaultToken.get), Option(vaultUrl))
697+
val principal = environment.getOrElse("principal", args.principal)
698+
val keytab = environment.getOrElse("keytabPath", args.keytab)
699+
700+
environment.foreach {
701+
case (key, value) => sysProps.put(key, value)
702+
}
703+
(principal, keytab)
721704

722-
case _ => (args.principal, args.keytab)
705+
} else {
706+
(args.principal, args.keytab)
723707
}
724708

725-
(childArgs, childClasspath, sysProps, childMainClass, pincipal, keytab)
709+
(childArgs, childClasspath, sysProps, childMainClass, principal, keytab)
710+
}
711+
712+
/**
713+
*
714+
* @param tempToken Temporal token, either Property one or Environment one
715+
* @param roleSecret Role and Secret ID, either Property one or Environment one
716+
* @param vaultUrl a Vault Url protocol://vaultHost:vaultPort
717+
* @return An option of a token
718+
*/
719+
private def getToken(tempToken: Option[String],
720+
roleSecret: Option[(String, String)],
721+
vaultUrl: String): Option[String] = {
722+
723+
(tempToken, roleSecret) match {
724+
case (Some(tempToken), _) => Some(VaultHelper.getRealToken(vaultUrl, tempToken))
725+
case (_, Some((role, secret))) =>
726+
Some(VaultHelper.getTokenFromAppRole(vaultUrl, role, secret))
727+
case _ => None
728+
}
729+
}
730+
731+
/**
732+
*
733+
* @param vaultUrlParams Is composed of Vault Protocol,
734+
* Vault Host and Vault Port
735+
* @return a Vault Url protocol://vaultHost:vaultPort
736+
*/
737+
private def buildVaultUrl(vaultUrlParams: (Option[String],
738+
Option[String],
739+
Option[String])): String = {
740+
741+
val vaultUrl = vaultUrlParams match {
742+
case (Some(protocol), Some(hosts), Some(port)) =>
743+
s"${protocol}://${
744+
hosts.split(",")
745+
.map(host => s"$host:${port}").mkString(",")}"
746+
case _ => ""
747+
}
748+
vaultUrl
726749
}
727750

728751
/**

0 commit comments

Comments
 (0)