Skip to content

Commit 9f354aa

Browse files
THWisemanaednichols
authored andcommitted
[WX-1260] Acquire sas token from task runner (#7241)
Co-authored-by: Adam Nichols <anichols@broadinstitute.org>
1 parent a88fd3f commit 9f354aa

File tree

13 files changed

+455
-55
lines changed

13 files changed

+455
-55
lines changed

backend/src/main/scala/cromwell/backend/standard/StandardAsyncExecutionActor.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package cromwell.backend.standard
22

33
import java.io.IOException
4-
54
import akka.actor.{Actor, ActorLogging, ActorRef}
65
import akka.event.LoggingReceive
76
import cats.implicits._
@@ -329,7 +328,7 @@ trait StandardAsyncExecutionActor
329328
}
330329

331330
/** Any custom code that should be run within commandScriptContents before the instantiated command. */
332-
def scriptPreamble: String = ""
331+
def scriptPreamble: ErrorOr[String] = "".valid
333332

334333
def cwd: Path = commandDirectory
335334
def rcPath: Path = cwd./(jobPaths.returnCodeFilename)
@@ -427,10 +426,12 @@ trait StandardAsyncExecutionActor
427426
|find . -type d -exec sh -c '[ -z "$$(ls -A '"'"'{}'"'"')" ] && touch '"'"'{}'"'"'/.file' \\;
428427
|)""".stripMargin)
429428

429+
val errorOrPreamble: ErrorOr[String] = scriptPreamble
430+
430431
// The `tee` trickery below is to be able to redirect to known filenames for CWL while also streaming
431432
// stdout and stderr for PAPI to periodically upload to cloud storage.
432433
// https://stackoverflow.com/questions/692000/how-do-i-write-stderr-to-a-file-while-using-tee-with-a-pipe
433-
(errorOrDirectoryOutputs, errorOrGlobFiles).mapN((directoryOutputs, globFiles) =>
434+
(errorOrDirectoryOutputs, errorOrGlobFiles, errorOrPreamble).mapN((directoryOutputs, globFiles, preamble) =>
434435
s"""|#!$jobShell
435436
|DOCKER_OUTPUT_DIR_LINK
436437
|cd ${cwd.pathAsString}
@@ -464,7 +465,7 @@ trait StandardAsyncExecutionActor
464465
|)
465466
|mv $rcTmpPath $rcPath
466467
|""".stripMargin
467-
.replace("SCRIPT_PREAMBLE", scriptPreamble)
468+
.replace("SCRIPT_PREAMBLE", preamble)
468469
.replace("ENVIRONMENT_VARIABLES", environmentVariables)
469470
.replace("INSTANTIATED_COMMAND", commandString)
470471
.replace("SCRIPT_EPILOGUE", scriptEpilogue)

filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import java.nio.file.spi.FileSystemProvider
1515
import java.time.temporal.ChronoUnit
1616
import java.time.{Duration, OffsetDateTime}
1717
import java.util.UUID
18+
import scala.collection.mutable
1819
import scala.jdk.CollectionConverters._
1920
import scala.util.{Failure, Success, Try}
2021

@@ -160,12 +161,14 @@ object BlobSasTokenGenerator {
160161
*/
161162
def createBlobTokenGenerator(workspaceManagerClient: WorkspaceManagerApiClientProvider,
162163
overrideWsmAuthToken: Option[String]): BlobSasTokenGenerator = {
163-
WSMBlobSasTokenGenerator(workspaceManagerClient, overrideWsmAuthToken)
164+
new WSMBlobSasTokenGenerator(workspaceManagerClient, overrideWsmAuthToken)
164165
}
165166

166167
}
167168

168-
case class WSMBlobSasTokenGenerator(wsmClientProvider: WorkspaceManagerApiClientProvider,
169+
case class WSMTerraCoordinates(wsmEndpoint: String, workspaceId: UUID, containerResourceId: UUID)
170+
171+
class WSMBlobSasTokenGenerator(wsmClientProvider: WorkspaceManagerApiClientProvider,
169172
overrideWsmAuthToken: Option[String]) extends BlobSasTokenGenerator {
170173

171174
/**
@@ -178,17 +181,14 @@ case class WSMBlobSasTokenGenerator(wsmClientProvider: WorkspaceManagerApiClient
178181
* @return an AzureSasCredential for accessing a blob container
179182
*/
180183
def generateBlobSasToken(endpoint: EndpointURL, container: BlobContainerName): Try[AzureSasCredential] = {
181-
val wsmAuthToken: Try[String] = overrideWsmAuthToken match {
182-
case Some(t) => Success(t)
183-
case None => AzureCredentials.getAccessToken(None).toTry
184-
}
184+
val wsmAuthToken: Try[String] = getWsmAuth
185185
container.workspaceId match {
186186
// If this is a Terra workspace, request a token from WSM
187187
case Success(workspaceId) => {
188188
(for {
189189
wsmAuth <- wsmAuthToken
190190
wsmAzureResourceClient = wsmClientProvider.getControlledAzureResourceApi(wsmAuth)
191-
resourceId <- getContainerResourceId(workspaceId, container, wsmAuth)
191+
resourceId <- getContainerResourceId(workspaceId, container, Option(wsmAuth))
192192
sasToken <- wsmAzureResourceClient.createAzureStorageContainerSasToken(workspaceId, resourceId)
193193
} yield sasToken).recoverWith {
194194
// If the storage account was still not found in WSM, this may be a public filesystem
@@ -201,9 +201,59 @@ case class WSMBlobSasTokenGenerator(wsmClientProvider: WorkspaceManagerApiClient
201201
}
202202
}
203203

204-
def getContainerResourceId(workspaceId: UUID, container: BlobContainerName, wsmAuth : String): Try[UUID] = {
205-
val wsmResourceClient = wsmClientProvider.getResourceApi(wsmAuth)
206-
wsmResourceClient.findContainerResourceId(workspaceId, container)
204+
private val cachedContainerResourceIds = new mutable.HashMap[BlobContainerName, UUID]()
205+
206+
// Optionally provide wsmAuth to avoid acquiring it twice in generateBlobSasToken.
207+
// In the case that the resourceId is not cached and no auth is provided, this function will acquire a new auth as necessary.
208+
private def getContainerResourceId(workspaceId: UUID, container: BlobContainerName, precomputedWsmAuth: Option[String]): Try[UUID] = {
209+
cachedContainerResourceIds.get(container) match {
210+
case Some(id) => Try(id) //cache hit
211+
case _ => { //cache miss
212+
val auth: Try[String] = precomputedWsmAuth.map(auth => Try(auth)).getOrElse(getWsmAuth)
213+
val resourceId = for {
214+
wsmAuth <- auth
215+
wsmResourceApi = wsmClientProvider.getResourceApi(wsmAuth)
216+
resourceId <- wsmResourceApi.findContainerResourceId(workspaceId, container)
217+
} yield resourceId
218+
resourceId.map(id => cachedContainerResourceIds.put(container, id)) //NB: Modifying cache state here.
219+
cachedContainerResourceIds.get(container) match {
220+
case Some(uuid) => Try(uuid)
221+
case _ => Failure(new NoSuchElementException("Could not retrieve container resource ID from WSM"))
222+
}
223+
}
224+
}
225+
}
226+
227+
private def getWsmAuth: Try[String] = {
228+
overrideWsmAuthToken match {
229+
case Some(t) => Success(t)
230+
case None => AzureCredentials.getAccessToken(None).toTry
231+
}
232+
}
233+
234+
private def parseTerraWorkspaceIdFromPath(blobPath: BlobPath): Try[UUID] = {
235+
if (blobPath.container.value.startsWith("sc-")) Try(UUID.fromString(blobPath.container.value.substring(3)))
236+
else Failure(new Exception("Could not parse workspace ID from storage container. Are you sure this is a file in a Terra Workspace?"))
237+
}
238+
239+
/**
240+
* Return a REST endpoint that will reply with a sas token for the blob storage container associated with the provided blob path.
241+
* @param blobPath A blob path of a file living in a blob container that WSM knows about (likely a workspace container).
242+
* @param tokenDuration How long will the token last after being generated. Default is 8 hours. Sas tokens won't last longer than 24h.
243+
* NOTE: If a blobPath is provided for a file in a container other than what this token generator was constructed for,
244+
* this function will make two REST requests. Otherwise, the relevant data is already cached locally.
245+
*/
246+
def getWSMSasFetchEndpoint(blobPath: BlobPath, tokenDuration: Option[Duration] = None): Try[String] = {
247+
val wsmEndpoint = wsmClientProvider.getBaseWorkspaceManagerUrl
248+
val lifetimeQueryParameters: String = tokenDuration.map(d => s"?sasExpirationDuration=${d.toSeconds.intValue}").getOrElse("")
249+
val terraInfo: Try[WSMTerraCoordinates] = for {
250+
workspaceId <- parseTerraWorkspaceIdFromPath(blobPath)
251+
containerResourceId <- getContainerResourceId(workspaceId, blobPath.container, None)
252+
coordinates = WSMTerraCoordinates(wsmEndpoint, workspaceId, containerResourceId)
253+
} yield coordinates
254+
terraInfo.map{terraCoordinates =>
255+
s"${terraCoordinates.wsmEndpoint}/api/workspaces/v1/${terraCoordinates.workspaceId.toString}/resources/controlled/azure/storageContainer/${terraCoordinates.containerResourceId.toString}/getSasToken${lifetimeQueryParameters}"
256+
}
207257
}
208258
}
209259

filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilder.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@ case class BlobPath private[blob](pathString: String, endpoint: EndpointURL, con
185185
* @return Path string relative to the container root.
186186
*/
187187
def pathWithoutContainer : String = pathString
188-
188+
189+
def getFilesystemManager: BlobFileSystemManager = fsm
190+
189191
override def getSymlinkSafePath(options: LinkOption*): Path = toAbsolutePath
192+
190193
}

filesystems/blob/src/main/scala/cromwell/filesystems/blob/WorkspaceManagerApiClientProvider.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import scala.util.Try
2020
trait WorkspaceManagerApiClientProvider {
2121
def getControlledAzureResourceApi(token: String): WsmControlledAzureResourceApi
2222
def getResourceApi(token: String): WsmResourceApi
23+
def getBaseWorkspaceManagerUrl: String
2324
}
2425

2526
class HttpWorkspaceManagerClientProvider(baseWorkspaceManagerUrl: WorkspaceManagerURL) extends WorkspaceManagerApiClientProvider {
@@ -40,6 +41,7 @@ class HttpWorkspaceManagerClientProvider(baseWorkspaceManagerUrl: WorkspaceManag
4041
apiClient.setAccessToken(token)
4142
WsmControlledAzureResourceApi(new ControlledAzureResourceApi(apiClient))
4243
}
44+
def getBaseWorkspaceManagerUrl: String = baseWorkspaceManagerUrl.value
4345
}
4446

4547
case class WsmResourceApi(resourcesApi : ResourceApi) {

supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/GcpBatchAsyncBackendJobExecutionActor.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -663,12 +663,12 @@ class GcpBatchAsyncBackendJobExecutionActor(override val standardParams: Standar
663663
private val DockerMonitoringLogPath: Path = GcpBatchWorkingDisk.MountPoint.resolve(gcpBatchCallPaths.batchMonitoringLogFilename)
664664
private val DockerMonitoringScriptPath: Path = GcpBatchWorkingDisk.MountPoint.resolve(gcpBatchCallPaths.batchMonitoringScriptFilename)
665665

666-
override def scriptPreamble: String = {
666+
override def scriptPreamble: ErrorOr[String] = {
667667
if (monitoringOutput.isDefined) {
668668
s"""|touch $DockerMonitoringLogPath
669669
|chmod u+x $DockerMonitoringScriptPath
670-
|$DockerMonitoringScriptPath > $DockerMonitoringLogPath &""".stripMargin
671-
} else ""
670+
|$DockerMonitoringScriptPath > $DockerMonitoringLogPath &""".stripMargin.valid
671+
} else "".valid
672672
}
673673

674674
private[actors] def generateInputs(jobDescriptor: BackendJobDescriptor): Set[GcpBatchInput] = {

supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/PipelinesApiAsyncBackendJobExecutionActor.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,12 +380,12 @@ class PipelinesApiAsyncBackendJobExecutionActor(override val standardParams: Sta
380380

381381
private lazy val isDockerImageCacheUsageRequested = runtimeAttributes.useDockerImageCache.getOrElse(useDockerImageCache(jobDescriptor.workflowDescriptor))
382382

383-
override def scriptPreamble: String = {
383+
override def scriptPreamble: ErrorOr[String] = {
384384
if (monitoringOutput.isDefined) {
385385
s"""|touch $DockerMonitoringLogPath
386386
|chmod u+x $DockerMonitoringScriptPath
387387
|$DockerMonitoringScriptPath > $DockerMonitoringLogPath &""".stripMargin
388-
} else ""
388+
}.valid else "".valid
389389
}
390390

391391
override def globParentDirectory(womGlobFile: WomGlobFile): Path = {

0 commit comments

Comments
 (0)