Skip to content

Azure Batch worker pool supports managed identity #5670

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/reference/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,11 @@ The following settings are available:
`azure.batch.pools.<name>.lowPriority`
: Enable the use of low-priority VMs (default: `false`).

`azure.batch.pools.<name>.managedIdentityId`
: :::{versionadded} 25.01.0-edge
:::
: Specify the client ID of the managed identity attached to the pool. The value is passed to each task through the `AZCOPY_MSI_CLIENT_ID` (azcopy) and `FUSION_AZ_MSI_CLIENT_ID` (Fusion) environment variables.

`azure.batch.pools.<name>.maxVmCount`
: Specify the maximum number of virtual machines when using the auto scale option.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import com.azure.compute.batch.models.ContainerConfiguration
import com.azure.compute.batch.models.ContainerRegistryReference
import com.azure.compute.batch.models.ContainerType
import com.azure.compute.batch.models.ElevationLevel
import com.azure.compute.batch.models.EnvironmentSetting
import com.azure.compute.batch.models.MetadataItem
import com.azure.compute.batch.models.MountConfiguration
import com.azure.compute.batch.models.NetworkConfiguration
Expand Down Expand Up @@ -452,6 +453,7 @@ class AzBatchService implements Closeable {
return key.size()>MAX_LEN ? key.substring(0,MAX_LEN) : key
}


protected BatchTaskCreateContent createTask(String poolId, String jobId, TaskRun task) {
assert poolId, 'Missing Azure Batch poolId argument'
assert jobId, 'Missing Azure Batch jobId argument'
Expand Down Expand Up @@ -492,37 +494,115 @@ class AzBatchService implements Closeable {

// Handle Fusion settings
final fusionEnabled = FusionHelper.isFusionEnabled((Session)Global.session)
final launcher = fusionEnabled ? FusionScriptLauncher.create(task.toTaskBean(), 'az') : null
String fusionCmd = null

if( fusionEnabled ) {
// Create the FusionScriptLauncher from the TaskBean
final taskBean = task.toTaskBean()
final launcher = FusionScriptLauncher.create(taskBean, 'az')

// Create the adapter that will manage the Fusion env with pool options
// TaskRun doesn't implement FusionAwareTask directly, so we need a wrapper
final adapter = new AzureBatchFusionAdapter(new AzFusionTaskWrapper(task), launcher, pool?.opts)

// Add container options
opts += "--privileged "
for( Map.Entry<String,String> it : launcher.fusionEnv() ) {

// Add all environment variables from the adapter
for( Map.Entry<String,String> it : adapter.getEnvironment() ) {
opts += "-e $it.key=$it.value "
}

// Get the fusion submit command
final List<String> cmdList = adapter.fusionSubmitCli()
fusionCmd = cmdList ? String.join(' ', cmdList) : null
}

// Create container settings

final containerOpts = new BatchTaskContainerSettings(container)
.setContainerRunOptions(opts)

// submit command line
final String cmd = fusionEnabled
? launcher.fusionSubmitCli(task).join(' ')
: "bash -o pipefail -c 'bash ${TaskRun.CMD_RUN} 2>&1 | tee ${TaskRun.CMD_LOG}'"
final String cmd = fusionEnabled && fusionCmd
? fusionCmd
: "bash -o pipefail -c 'bash ${TaskRun.CMD_RUN} 2>&1 | tee ${TaskRun.CMD_LOG}'"
// cpus and memory
final slots = computeSlots(task, pool)
// max wall time
final constraints = constraints(task)

log.trace "[AZURE BATCH] Submitting task: $taskId, cpus=${task.config.getCpus()}, mem=${task.config.getMemory()?:'-'}, slots: $slots"

// Add environment variables for managed identity if configured
final env = [:] as Map<String,String>
if( pool?.opts?.managedIdentityId ) {
env.put('AZCOPY_AUTO_LOGIN_TYPE', 'MSI') // azcopy
env.put('AZCOPY_MSI_CLIENT_ID', pool.opts.managedIdentityId) // azcopy
Comment on lines +538 to +539
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These look unrelated

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is what makes azcopy work.

The first iteration of this PR did not include Fusion.

env.put('FUSION_AZ_MSI_CLIENT_ID', pool.opts.managedIdentityId) // fusion
}

return createBatchTaskContent(
taskId,
cmd,
userIdentity(pool.opts.privileged, pool.opts.runAs, AutoUserScope.TASK),
containerOpts,
resourceFileUrls(task, sas),
outputFileUrls(task, sas),
slots,
constraints,
env
)
}

/**
 * Build the Azure Batch constraints that apply to a task.
 *
 * The only constraint set here is the maximum wall-clock time, and it is
 * set only when the task configuration defines a time limit.
 *
 * @param task The task run whose configuration is inspected
 * @return A BatchTaskConstraints instance; its max wall-clock time is set when a time limit is configured
 */
protected BatchTaskConstraints constraints(TaskRun task) {
    final result = new BatchTaskConstraints()
    final timeLimit = task.config.getTime()
    if( timeLimit ) {
        // convert the configured time limit to a java.time duration via its millisecond value
        final maxWallClock = Duration.of(timeLimit.toMillis(), ChronoUnit.MILLIS)
        result.setMaxWallClockTime(maxWallClock)
    }
    return result
}

log.trace "[AZURE BATCH] Submitting task: $taskId, cpus=${task.config.getCpus()}, mem=${task.config.getMemory()?:'-'}, slots: $slots"
/**
 * Create a BatchTaskCreateContent object from already-resolved settings.
 *
 * Every value applied to the task comes from the method arguments; nothing is
 * looked up from the pool or the task run here.
 *
 * @param taskId Task ID
 * @param cmd Command to run
 * @param userIdentity User identity the task runs as
 * @param containerSettings Container settings
 * @param resourceFiles Resource files
 * @param outputFiles Output files
 * @param slots Required slots
 * @param constraints Task constraints (may be null; only used for tracing and passed through)
 * @param env Environment variables to set on the task; a null or empty map results in no environment settings
 * @return The BatchTaskCreateContent object
 */
protected BatchTaskCreateContent createBatchTaskContent(
        String taskId,
        String cmd,
        UserIdentity userIdentity,
        BatchTaskContainerSettings containerSettings,
        List<ResourceFile> resourceFiles,
        List<OutputFile> outputFiles,
        int slots,
        BatchTaskConstraints constraints,
        Map<String,String> env) {

    log.trace "[AZURE BATCH] Task details: id=$taskId, slots=$slots, constraints=${constraints?.maxWallClockTime}"

    // Turn each name/value pair into an EnvironmentSetting; a null map is treated as empty
    final List<EnvironmentSetting> envSettings = new ArrayList<>()
    if( env ) {
        for( Map.Entry<String,String> entry : env.entrySet() ) {
            envSettings.add(new EnvironmentSetting(entry.key).setValue(entry.value))
        }
    }

    // Only the method arguments are applied: the former inline calls that read
    // `pool`, `containerOpts`, `task` and `sas` do not belong in this method
    return new BatchTaskCreateContent(taskId, cmd)
        .setUserIdentity(userIdentity)
        .setContainerSettings(containerSettings)
        .setResourceFiles(resourceFiles)
        .setOutputFiles(outputFiles)
        .setRequiredSlots(slots)
        .setConstraints(constraints)
        .setEnvironmentSettings(envSettings)
}

AzTaskKey runTask(String poolId, String jobId, TaskRun task) {
Expand Down Expand Up @@ -583,6 +663,13 @@ class AzBatchService implements Closeable {
List<OutputFile> result = new ArrayList<>(20)
result << destFile(TaskRun.CMD_EXIT, task.workDir, sas)
result << destFile(TaskRun.CMD_LOG, task.workDir, sas)
result << destFile(TaskRun.CMD_OUTFILE, task.workDir, sas)
result << destFile(TaskRun.CMD_ERRFILE, task.workDir, sas)
result << destFile(TaskRun.CMD_SCRIPT, task.workDir, sas)
result << destFile(TaskRun.CMD_RUN, task.workDir, sas)
result << destFile(TaskRun.CMD_STAGE, task.workDir, sas)
result << destFile(TaskRun.CMD_TRACE, task.workDir, sas)
result << destFile(TaskRun.CMD_ENV, task.workDir, sas)
return result
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ class AzFileCopyStrategy extends SimpleFileCopyStrategy {
final result = new StringBuilder()
final copy = environment ? new LinkedHashMap<String,String>(environment) : new LinkedHashMap<String,String>()
copy.remove('PATH')
copy.put('PATH', '$PWD/.nextflow-bin:$AZ_BATCH_NODE_SHARED_DIR/bin/:$PATH')
copy.put('AZCOPY_LOG_LOCATION', '$PWD/.azcopy_log')
copy.put('PATH', '$AZ_BATCH_TASK_DIR/.nextflow-bin:$AZ_BATCH_NODE_SHARED_DIR/bin/:$PATH')
copy.put('AZCOPY_LOG_LOCATION', '$AZ_BATCH_TASK_DIR/.azcopy_log')
copy.put('AZCOPY_JOB_PLAN_LOCATION', '$AZ_BATCH_TASK_DIR/.azcopy_log')
copy.put('AZ_SAS', sasToken)

// finally render the environment
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright 2021, Microsoft Corp
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package nextflow.cloud.azure.batch

import groovy.transform.CompileStatic
import groovy.util.logging.Slf4j
import nextflow.executor.Executor
import nextflow.fusion.FusionAwareTask
import nextflow.fusion.FusionConfig
import nextflow.fusion.FusionScriptLauncher
import nextflow.processor.TaskRun

/**
 * Adapter class that wraps a TaskRun to implement the FusionAwareTask interface.
 *
 * Fusion is always reported as enabled, and the script launcher is created
 * on first request and then reused.
 *
 * @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
 */
@Slf4j
@CompileStatic
class AzFusionTaskWrapper implements FusionAwareTask {

    /** The wrapped task run */
    private final TaskRun task

    /** Launcher created on the first call to fusionLauncher() and cached afterwards */
    private FusionScriptLauncher fusionLauncher

    /**
     * @param task The task run to expose as a FusionAwareTask
     */
    AzFusionTaskWrapper(TaskRun task) {
        this.task = task
    }

    /**
     * @return The wrapped task run
     */
    @Override
    TaskRun getTask() {
        return task
    }

    /**
     * @return Always {@code true}
     */
    @Override
    boolean fusionEnabled() {
        return true // Always true since we only create this wrapper when fusion is enabled
    }

    /**
     * @return The Fusion configuration obtained from FusionConfig.getConfig()
     */
    @Override
    FusionConfig fusionConfig() {
        return FusionConfig.getConfig()
    }

    /**
     * Get the launcher for the wrapped task, creating it on first use from the
     * task bean and the scheme of the task work directory.
     *
     * @return The cached FusionScriptLauncher
     */
    @Override
    FusionScriptLauncher fusionLauncher() {
        if (fusionLauncher == null) {
            fusionLauncher = FusionScriptLauncher.create(task.toTaskBean(), task.workDir.scheme)
        }
        return fusionLauncher
    }

    /**
     * @return The Fusion submit command line for the wrapped task, as produced by the launcher
     */
    @Override
    List<String> fusionSubmitCli() {
        return fusionLauncher().fusionSubmitCli(task)
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright 2021, Microsoft Corp
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package nextflow.cloud.azure.batch

import groovy.transform.CompileStatic
import nextflow.cloud.azure.config.AzPoolOpts
import nextflow.processor.TaskBean
import nextflow.processor.TaskRun

/**
 * TaskBean specialisation for Azure that also carries the Azure Batch pool options
 *
 * @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
 */
@CompileStatic
class AzTaskBean extends TaskBean {

    /**
     * The Azure pool options that apply to this task.
     *
     * Declared as a read-only Groovy property, so the {@code getPoolOpts()}
     * accessor is generated by the compiler.
     */
    final AzPoolOpts poolOpts

    /**
     * Build an Azure task bean from a task run and the pool options it is bound to
     *
     * @param task The TaskRun providing the base configuration
     * @param poolOpts The Azure pool options (may be null)
     */
    AzTaskBean(TaskRun task, AzPoolOpts poolOpts) {
        super(task)
        this.poolOpts = poolOpts
    }

    /**
     * Look up the managed identity client ID from the pool options
     *
     * @return The managed identity client ID, or null when there are no pool options or none is configured
     */
    String getManagedIdentityId() {
        if( poolOpts == null )
            return null
        return poolOpts.managedIdentityId
    }
}
Loading
Loading