Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@
*/
package cromwell.backend.impl.aws

import java.security.MessageDigest

import cats.data.ReaderT._
import cats.data.{Kleisli, ReaderT}
import cats.effect.{Async, Timer}
Expand All @@ -53,8 +51,9 @@ import software.amazon.awssdk.services.s3.S3Client
import software.amazon.awssdk.services.s3.model.{GetObjectRequest, HeadObjectRequest, NoSuchKeyException, PutObjectRequest}
import wdl4s.parser.MemoryUnit

import scala.jdk.CollectionConverters._
import java.security.MessageDigest
import scala.concurrent.duration._
import scala.jdk.CollectionConverters._
import scala.util.{Random, Try}

/**
Expand Down Expand Up @@ -256,7 +255,6 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
SubmitJobRequest.builder()
.jobName(sanitize(jobDescriptor.taskCall.fullyQualifiedName))
.parameters(parameters.collect({ case i: AwsBatchInput => i.toStringString }).toMap.asJava)

//provide job environment variables, vcpu and memory
.containerOverrides(
ContainerOverrides.builder
Expand All @@ -276,6 +274,7 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
)
.build()
)
.tags(runtimeAttributes.resourceTags.asJava)
.jobQueue(runtimeAttributes.queueArn)
.jobDefinition(definitionArn)
.build
Expand Down Expand Up @@ -462,7 +461,7 @@ final case class AwsBatchJob(jobDescriptor: BackendJobDescriptor, // WDL/CWL
def output(detail: JobDetail): String = {
val events: Seq[OutputLogEvent] = cloudWatchLogsClient.getLogEvents(GetLogEventsRequest.builder
// http://aws-java-sdk-javadoc.s3-website-us-west-2.amazonaws.com/latest/software/amazon/awssdk/services/batch/model/ContainerDetail.html#logStreamName--
.logGroupName("/aws/batch/job")
.logGroupName(runtimeAttributes.logsGroup)
.logStreamName(detail.container.logStreamName)
.startFromHead(true)
.build).events.asScala.toList
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ package cromwell.backend.impl.aws
import scala.collection.mutable.ListBuffer
import cromwell.backend.BackendJobDescriptor
import cromwell.backend.io.JobPaths
import software.amazon.awssdk.services.batch.model.{ContainerProperties, Host, KeyValuePair, MountPoint, ResourceRequirement, ResourceType, Volume}
import software.amazon.awssdk.services.batch.model.{ContainerProperties, Host, KeyValuePair, LogConfiguration, MountPoint, ResourceRequirement, ResourceType, Volume}
import cromwell.backend.impl.aws.io.AwsBatchVolume

import scala.jdk.CollectionConverters._
Expand Down Expand Up @@ -128,8 +128,8 @@ trait AwsBatchJobDefinitionBuilder {
)
}

def buildName(imageName: String, packedCommand: String, volumes: List[Volume], mountPoints: List[MountPoint], env: Seq[KeyValuePair]): String = {
val str = s"$imageName:$packedCommand:${volumes.map(_.toString).mkString(",")}:${mountPoints.map(_.toString).mkString(",")}:${env.map(_.toString).mkString(",")}"
def buildName(imageName: String, packedCommand: String, volumes: List[Volume], mountPoints: List[MountPoint], env: Seq[KeyValuePair], logsGroup: String): String = {
val str = s"$imageName:$packedCommand:${volumes.map(_.toString).mkString(",")}:${mountPoints.map(_.toString).mkString(",")}:${env.map(_.toString).mkString(",")}:$logsGroup"

val sha1 = MessageDigest.getInstance("SHA-1")
.digest( str.getBytes("UTF-8") )
Expand All @@ -148,26 +148,36 @@ trait AwsBatchJobDefinitionBuilder {
val packedCommand = packCommand("/bin/bash", "-c", cmdName)
val volumes = buildVolumes( context.runtimeAttributes.disks )
val mountPoints = buildMountPoints( context.runtimeAttributes.disks)
val logConfiguration = LogConfiguration.builder()
.logDriver("awslogs")
.options(
Map(
"awslogs-group" -> context.runtimeAttributes.logsGroup
).asJava
)
.build()
val jobDefinitionName = buildName(
context.runtimeAttributes.dockerImage,
packedCommand.mkString(","),
volumes,
mountPoints,
environment
environment,
context.runtimeAttributes.logsGroup
)

(builder
.command(packedCommand.asJava)
.resourceRequirements(
ResourceRequirement.builder()
.`type`(ResourceType.MEMORY)
.value(context.runtimeAttributes.memory.to(MemoryUnit.MB).amount.toInt.toString)
.build(),
ResourceRequirement.builder()
.`type`(ResourceType.VCPU)
.value(context.runtimeAttributes.cpu.value.toString)
.build(),
)
.resourceRequirements(
ResourceRequirement.builder()
.`type`(ResourceType.MEMORY)
.value(context.runtimeAttributes.memory.to(MemoryUnit.MB).amount.toInt.toString)
.build(),
ResourceRequirement.builder()
.`type`(ResourceType.VCPU)
.value(context.runtimeAttributes.cpu.value.toString)
.build(),
)
.logConfiguration(logConfiguration)
.volumes( volumes.asJava)
.mountPoints( mountPoints.asJava)
.environment(environment.asJava),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import wom.format.MemorySize
import wom.types._
import wom.values._

import scala.jdk.CollectionConverters._
import scala.util.matching.Regex

/**
Expand All @@ -60,6 +61,8 @@ import scala.util.matching.Regex
* @param noAddress is there no address
* @param scriptS3BucketName the s3 bucket where the execution command or script will be written and, from there, fetched into the container and executed
* @param fileSystem the filesystem type, default is "s3"
* @param logsGroup the CloudWatch log group name to write logs to
* @param resourceTags a map of tags to add to the AWS Batch job submission
*/
case class AwsBatchRuntimeAttributes(cpu: Int Refined Positive,
zones: Vector[String],
Expand All @@ -71,7 +74,10 @@ case class AwsBatchRuntimeAttributes(cpu: Int Refined Positive,
continueOnReturnCode: ContinueOnReturnCode,
noAddress: Boolean,
scriptS3BucketName: String,
fileSystem:String= "s3")
logsGroup: String,
resourceTags: Map[String, String],
fileSystem: String = "s3") {
}

object AwsBatchRuntimeAttributes {

Expand All @@ -92,6 +98,12 @@ object AwsBatchRuntimeAttributes {

private val MemoryDefaultValue = "2 GB"

private val logsGroupKey = "logsGroup"
private val logsGroupValidationInstance = new StringRuntimeAttributesValidation(logsGroupKey)
private val LogsGroupDefaultValue = WomString("/aws/batch/job")

private val resourceTagsKey = "resourceTags"

private def cpuValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Int Refined Positive] = CpuValidation.instance
.withDefault(CpuValidation.configDefaultWomValue(runtimeConfig) getOrElse CpuValidation.defaultMin)

Expand Down Expand Up @@ -123,6 +135,9 @@ object AwsBatchRuntimeAttributes {
private def noAddressValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[Boolean] = noAddressValidationInstance
.withDefault(noAddressValidationInstance.configDefaultWomValue(runtimeConfig) getOrElse NoAddressDefaultValue)

private def logsGroupValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[String] = logsGroupValidationInstance
.withDefault(logsGroupValidationInstance.configDefaultWomValue(runtimeConfig) getOrElse LogsGroupDefaultValue)

private def scriptS3BucketNameValidation(runtimeConfig: Option[Config]): RuntimeAttributesValidation[String] = {
ScriptS3BucketNameValidation(scriptS3BucketKey).withDefault(ScriptS3BucketNameValidation(scriptS3BucketKey)
.configDefaultWomValue(runtimeConfig).getOrElse( throw new RuntimeException( "scriptBucketName is required" )))
Expand All @@ -146,7 +161,8 @@ object AwsBatchRuntimeAttributes {
noAddressValidation(runtimeConfig),
dockerValidation,
queueArnValidation(runtimeConfig),
scriptS3BucketNameValidation(runtimeConfig)
scriptS3BucketNameValidation(runtimeConfig),
logsGroupValidation(runtimeConfig)
)
def validationsLocalBackend = StandardValidatedRuntimeAttributesBuilder.default(runtimeConfig).withValidation(
cpuValidation(runtimeConfig),
Expand All @@ -157,7 +173,8 @@ object AwsBatchRuntimeAttributes {
memoryMinValidation(runtimeConfig),
noAddressValidation(runtimeConfig),
dockerValidation,
queueArnValidation(runtimeConfig)
queueArnValidation(runtimeConfig),
logsGroupValidation(runtimeConfig)
)

configuration.fileSystem match {
Expand All @@ -181,7 +198,14 @@ object AwsBatchRuntimeAttributes {
case AWSBatchStorageSystems.s3 => RuntimeAttributesValidation.extract(scriptS3BucketNameValidation(runtimeAttrsConfig) , validatedRuntimeAttributes)
case _ => ""
}
val logsGroup: String = RuntimeAttributesValidation.extract(logsGroupValidation(runtimeAttrsConfig), validatedRuntimeAttributes)

val resourceTags: Map[String, String] = runtimeAttrsConfig.collect {
case config if config.hasPath(resourceTagsKey) =>
config.getObject(resourceTagsKey).entrySet().asScala
.map(e => e.getKey -> e.getValue.unwrapped().toString)
.toMap
}.getOrElse(Map.empty[String, String])

new AwsBatchRuntimeAttributes(
cpu,
Expand All @@ -194,6 +218,8 @@ object AwsBatchRuntimeAttributes {
continueOnReturnCode,
noAddress,
scriptS3BucketName,
logsGroup,
resourceTags,
fileSystem
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ class AwsBatchJobSpec extends TestKitSuite with AnyFlatSpecLike with Matchers wi
continueOnReturnCode = ContinueOnReturnCodeFlag(false),
noAddress = false,
scriptS3BucketName = "script-bucket",
fileSystem = "s3")
fileSystem = "s3",
logsGroup = "/aws/batch/job",
resourceTags = Map("tag" -> "value"))

private def generateBasicJob: AwsBatchJob = {
val job = AwsBatchJob(null, runtimeAttributes, "commandLine", script,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,25 +56,31 @@ class AwsBatchRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeout
)))
}

val expectedDefaults = new AwsBatchRuntimeAttributes(refineMV[Positive](1), Vector("us-east-1a", "us-east-1b"),

val expectedDefaults = new AwsBatchRuntimeAttributes(
refineMV[Positive](1),
Vector("us-east-1a", "us-east-1b"),
MemorySize(2, MemoryUnit.GB), Vector(AwsBatchWorkingDisk()),
"ubuntu:latest",
"arn:aws:batch:us-east-1:111222333444:job-queue/job-queue",
false,
ContinueOnReturnCodeSet(Set(0)),
false,
"my-stuff")

val expectedDefaultsLocalFS = new AwsBatchRuntimeAttributes(refineMV[Positive](1), Vector("us-east-1a", "us-east-1b"),
"my-stuff",
"/Cromwell/job/",
Map("tag1" -> "value1"))

val expectedDefaultsLocalFS = new AwsBatchRuntimeAttributes(
refineMV[Positive](1),
Vector("us-east-1a", "us-east-1b"),
MemorySize(2, MemoryUnit.GB), Vector(AwsBatchWorkingDisk()),
"ubuntu:latest",
"arn:aws:batch:us-east-1:111222333444:job-queue/job-queue",
false,
ContinueOnReturnCodeSet(Set(0)),
false,
"",
"/Cromwell/job/",
Map(),
"local")

"AwsBatchRuntimeAttributes" should {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ object AwsBatchTestConfig {
| zones:["us-east-1a", "us-east-1b"]
| queueArn: "arn:aws:batch:us-east-1:111222333444:job-queue/job-queue"
| scriptBucketName: "my-bucket"
| logsGroup: "/Cromwell/job/"
| resourceTags {
| tag1: "value1"
| }
|}
|
|""".stripMargin
Expand Down Expand Up @@ -140,6 +144,7 @@ object AwsBatchTestConfigForLocalFS {
| zones:["us-east-1a", "us-east-1b"]
| queueArn: "arn:aws:batch:us-east-1:111222333444:job-queue/job-queue"
| scriptBucketName: ""
| logsGroup: "/Cromwell/job/"
|}
|
|""".stripMargin
Expand Down Expand Up @@ -190,4 +195,4 @@ object AwsBatchTestConfigForLocalFS {
val AwsBatchBackendNoDefaultConfig = ConfigFactory.parseString(NoDefaultsConfigString)
val AwsBatchBackendConfigurationDescriptor = BackendConfigurationDescriptor(AwsBatchBackendConfig, AwsBatchGlobalConfig)
val NoDefaultsConfigurationDescriptor = BackendConfigurationDescriptor(AwsBatchBackendNoDefaultConfig, AwsBatchGlobalConfig)
}
}