Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(AzBatchService): Allow Azure Batch tasks to be submitted to different pools #5766

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ import nextflow.cloud.types.CloudMachineInfo
import nextflow.cloud.types.PriceModel
import nextflow.fusion.FusionHelper
import nextflow.fusion.FusionScriptLauncher
import nextflow.processor.TaskProcessor
import nextflow.processor.TaskRun
import nextflow.util.CacheHelper
import nextflow.util.MemoryUnit
Expand All @@ -111,7 +110,7 @@ class AzBatchService implements Closeable {

AzConfig config

// Cache of the Azure Batch job ids already created, keyed by AzJobKey
Map<AzJobKey,String> allJobIds = new HashMap<>(50)

AzBatchService(AzBatchExecutor executor) {
assert executor
Expand Down Expand Up @@ -355,17 +354,26 @@ class AzBatchService implements Closeable {
}

/**
 * Return the id of the Azure Batch job to use for the given task and pool,
 * creating the job the first time it is requested.
 *
 * The same job id is used for the same (process, pool id) pair. The pool id
 * is part of the key to allow using different queue names (each corresponding
 * to a pool id) for the same process. See also
 * https://github.com/nextflow-io/nextflow/pull/5766
 *
 * @param poolId The id of the pool the job is associated with
 * @param task The task for which the job is required
 * @return The id of the cached or newly created job
 */
synchronized String getOrCreateJob(String poolId, TaskRun task) {
    final mapKey = new AzJobKey(task.processor, poolId)
    if( allJobIds.containsKey(mapKey)) {
        return allJobIds[mapKey]
    }
    final jobId = createJob0(poolId,task)
    // add to the map so the next request for the same key reuses this job
    allJobIds[mapKey] = jobId
    return jobId
}

/**
 * Create a new Azure Batch job bound to the given pool.
 *
 * This method only submits the job; it does not record the id in
 * {@code allJobIds}, which is done by {@code getOrCreateJob}.
 *
 * @param poolId The id of the pool the job is associated with
 * @param task The task used to derive the job id via {@code makeJobId}
 * @return The id of the job submitted for creation
 */
protected String createJob0(String poolId, TaskRun task) {
    // NOTE: this message is logged before the create request is issued
    log.debug "[AZURE BATCH] created job for ${task.processor.name} with pool ${poolId}"
    // create a batch job
    final jobId = makeJobId(task)
    final content = new BatchJobCreateContent(jobId, new BatchPoolInfo(poolId: poolId))
    apply(() -> client.createJob(content))
    return jobId
}

String makeJobId(TaskRun task) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright 2013-2024, Seqera Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package nextflow.cloud.azure.batch

import groovy.transform.Canonical
import groovy.transform.CompileStatic
import nextflow.processor.TaskProcessor
/**
* Model a Batch job key for caching purposes.
*
* A key is the pair made of a task processor and a pool id. The class relies
* on {@code @Canonical} for its constructor, {@code equals}, {@code hashCode}
* and {@code toString} implementations, so two keys holding the same
* processor and pool id can be used interchangeably as map keys.
*
* @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
*/
@Canonical
@CompileStatic
class AzJobKey {
// the processor of the task the job is created for
final TaskProcessor processor
// the id of the pool the job is associated with
final String poolId
}
Original file line number Diff line number Diff line change
Expand Up @@ -739,4 +739,54 @@ class AzBatchServiceTest extends Specification {
[managedIdentity: [clientId: 'client-123']] | 'client-123'
}

// Verifies that getOrCreateJob invokes createJob0 once for each distinct
// (processor, pool id) pair and returns the previously obtained job id otherwise.
// The when/then blocks below share the same service instance, so each step
// depends on the cache entries added by the preceding ones.
def 'should cache job id' () {
given:
def exec = Mock(AzBatchExecutor)
// spy on a real service so that createJob0 can be stubbed and its invocations counted
def service = Spy(new AzBatchService(exec))
and:
def p1 = Mock(TaskProcessor)
def p2 = Mock(TaskProcessor)
def t1 = Mock(TaskRun) { getProcessor()>>p1 }
// t2 and t3 are distinct tasks sharing the same processor (p2)
def t2 = Mock(TaskRun) { getProcessor()>>p2 }
def t3 = Mock(TaskRun) { getProcessor()>>p2 }

// first request for (p1, 'foo') creates the job
when:
def result = service.getOrCreateJob('foo',t1)
then:
1 * service.createJob0('foo',t1) >> 'job1'
and:
result == 'job1'

// second time is cached
when:
result = service.getOrCreateJob('foo',t1)
then:
0 * service.createJob0('foo',t1) >> null
and:
result == 'job1'

// changing pool id returns a new job id
when:
result = service.getOrCreateJob('bar',t1)
then:
1 * service.createJob0('bar',t1) >> 'job2'
and:
result == 'job2'

// changing process returns a new job id
when:
result = service.getOrCreateJob('bar',t2)
then:
1 * service.createJob0('bar',t2) >> 'job3'
and:
result == 'job3'

// change task with the same process, return cached job id
when:
result = service.getOrCreateJob('bar',t3)
then:
0 * service.createJob0('bar',t3) >> null
and:
result == 'job3'
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright 2013-2024, Seqera Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package nextflow.cloud.azure.batch

import nextflow.processor.TaskProcessor
import spock.lang.Specification

/**
 * Check that {@link AzJobKey} instances compare by processor and pool id
 *
 * @author Paolo Di Tommaso <paolo.ditommaso@gmail.com>
 */
class AzJobKeyTest extends Specification {

    def 'should validate equals and hashcode' () {
        given:
        // two independent processors
        def firstProcessor = Mock(TaskProcessor)
        def secondProcessor = Mock(TaskProcessor)
        // a reference key, a key with the same values, and two keys
        // differing by processor and by pool id respectively
        def referenceKey = new AzJobKey(firstProcessor, 'foo')
        def sameValuesKey = new AzJobKey(firstProcessor, 'foo')
        def otherProcessorKey = new AzJobKey(secondProcessor, 'foo')
        def otherPoolKey = new AzJobKey(firstProcessor, 'bar')

        expect:
        referenceKey == sameValuesKey
        referenceKey != otherProcessorKey
        referenceKey != otherPoolKey
        and:
        referenceKey.hashCode() == sameValuesKey.hashCode()
        referenceKey.hashCode() != otherProcessorKey.hashCode()
        referenceKey.hashCode() != otherPoolKey.hashCode()
    }

}