Skip to content

Commit 6509af9

Browse files
committed
Include requester pays project id in DRS localizer call
1 parent f8f0dd4 commit 6509af9

File tree

2 files changed

+45
-5
lines changed

2 files changed

+45
-5
lines changed

supportedBackends/google/pipelines/v2beta/src/main/scala/cromwell/backend/google/pipelines/v2beta/api/Localization.scala

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ import com.typesafe.config.ConfigFactory
66
import cromwell.backend.google.pipelines.common.action.ActionCommands.localizeFile
77
import cromwell.backend.google.pipelines.common.action.ActionLabels._
88
import cromwell.backend.google.pipelines.common.PipelinesApiConfigurationAttributes.GcsTransferConfiguration
9+
import cromwell.backend.google.pipelines.common.PipelinesApiFileInput
910
import cromwell.backend.google.pipelines.common.PipelinesApiJobPaths._
1011
import cromwell.backend.google.pipelines.common.api.PipelinesApiRequestFactory.CreatePipelineParameters
1112
import cromwell.backend.google.pipelines.v2beta.PipelinesConversions._
1213
import cromwell.backend.google.pipelines.v2beta.ToParameter.ops._
1314
import cromwell.backend.google.pipelines.v2beta.api.ActionBuilder.{EnhancedAction, cloudSdkShellAction}
1415
import cromwell.core.path.Path
16+
import cromwell.filesystems.drs.DrsPath
1517

1618
import scala.jdk.CollectionConverters._
1719

@@ -44,7 +46,12 @@ trait Localization {
4446
val runGcsLocalizationScript = cloudSdkShellAction(
4547
s"/bin/bash $gcsLocalizationContainerPath")(mounts = mounts, labels = localizationLabel)
4648

47-
val runDrsLocalization = Localization.drsAction(drsLocalizationManifestContainerPath, mounts, localizationLabel)
49+
// Requester pays project id is stored on each DrsPath, but will be the same for all DRS inputs to a
50+
// particular workflow because it's determined by the Google project set in workflow options.
51+
val requesterPaysProjectId: Option[String] = createPipelineParameters.inputOutputParameters.fileInputParameters.collect {
52+
case PipelinesApiFileInput(_, drsPath: DrsPath, _, _) => drsPath.requesterPaysProjectIdOption
53+
}.flatten.headOption
54+
val runDrsLocalization = Localization.drsAction(drsLocalizationManifestContainerPath, mounts, localizationLabel, requesterPaysProjectId)
4855

4956
// Any "classic" PAPI v2 one-at-a-time localizations for non-GCS inputs.
5057
val singletonLocalizations = createPipelineParameters.inputOutputParameters.fileInputParameters.flatMap(_.toActions(mounts).toList)
@@ -62,14 +69,20 @@ trait Localization {
6269

6370
object Localization {
6471

65-
def drsAction(manifestPath: Path, mounts: List[Mount], labels: Map[String, String]) = {
66-
// TODO: Is this an acceptable way to read this config?
72+
def drsAction(manifestPath: Path,
73+
mounts: List[Mount],
74+
labels: Map[String, String],
75+
requesterPaysProjectId: Option[String]
76+
): Action = {
6777
val config = ConfigFactory.load
6878
val marthaConfig = config.getConfig("filesystems.drs.global.config.martha")
6979
val drsConfig = DrsConfig.fromConfig(marthaConfig)
7080
val drsDockerImage = config.getString("drs.localization.docker-image")
7181

72-
val drsCommand = List("-m", manifestPath.pathAsString)
82+
val manifestArg = List("-m", manifestPath.pathAsString)
83+
val requesterPaysArg = requesterPaysProjectId.map(r => List("-r", r)).getOrElse(List.empty)
84+
val drsCommand = manifestArg ++ requesterPaysArg
85+
7386
val marthaEnv = DrsConfig.toEnv(drsConfig)
7487
ActionBuilder
7588
.withImage(drsDockerImage)

supportedBackends/google/pipelines/v2beta/src/test/scala/cromwell/backend/google/pipelines/v2beta/api/LocalizationSpec.scala

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class LocalizationSpec extends AnyFlatSpec with CromwellTimeoutSpec with Matcher
1919
val tagKey = "tag"
2020
val tagLabel = "myLabel"
2121

22-
val action = Localization.drsAction(manifestPath, Nil, Map(tagKey -> tagLabel))
22+
val action = Localization.drsAction(manifestPath, Nil, Map(tagKey -> tagLabel), None)
2323
action.keySet.asScala should contain theSameElementsAs
2424
Set("commands", "environment", "imageUri", "labels", "mounts")
2525

@@ -37,4 +37,31 @@ class LocalizationSpec extends AnyFlatSpec with CromwellTimeoutSpec with Matcher
3737
actionLabels.keySet.asScala should contain theSameElementsAs List("tag")
3838
actionLabels.get(tagKey) should be(tagLabel)
3939
}
40+
41+
it should "create the right action to localize DRS files using a manifest with requester pays" in {
42+
43+
val manifestPathString = s"path/to/${DrsLocalizationManifestName}"
44+
val manifestPath = DefaultPathBuilder.get(manifestPathString)
45+
val tagKey = "tag"
46+
val tagLabel = "myLabel"
47+
val requesterPaysProjectId = "123"
48+
49+
val action = Localization.drsAction(manifestPath, Nil, Map(tagKey -> tagLabel), Option(requesterPaysProjectId))
50+
action.keySet.asScala should contain theSameElementsAs
51+
Set("commands", "environment", "imageUri", "labels", "mounts")
52+
53+
action.get("commands") should be(a[java.util.List[_]])
54+
action.get("commands").asInstanceOf[java.util.List[_]] should contain theSameElementsAs List(
55+
"-m", manifestPathString, "-r", requesterPaysProjectId
56+
)
57+
58+
action.get("mounts") should be(a[java.util.List[_]])
59+
action.get("mounts").asInstanceOf[java.util.List[_]] should be (empty)
60+
61+
action.get("imageUri") should be("somerepo/drs-downloader:tagged")
62+
63+
val actionLabels = action.get("labels").asInstanceOf[java.util.Map[_, _]]
64+
actionLabels.keySet.asScala should contain theSameElementsAs List("tag")
65+
actionLabels.get(tagKey) should be(tagLabel)
66+
}
4067
}

0 commit comments

Comments
 (0)