Manage NVIDIA GPU slots in local executor #5850

Draft · wants to merge 4 commits into base: master
6 changes: 6 additions & 0 deletions docs/reference/config.md
@@ -642,6 +642,12 @@ The following settings are available:
`executor.exitReadTimeout`
: Determines how long to wait before returning an error status when a process is terminated but the `.exitcode` file does not exist or is empty (default: `270 sec`). Used only by grid executors.

`executor.gpus`
: :::{versionadded} 25.04.0
:::
: *Used only by the `local` executor.*
: The maximum number of NVIDIA GPUs that the underlying system makes available to tasks. When this setting is specified, each local task is assigned GPUs according to its `accelerator` request, and the assigned devices are exposed to the task via the `CUDA_VISIBLE_DEVICES` environment variable.
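A minimal usage sketch for this setting (illustrative values, not part of this diff), pairing `executor.gpus` with the `accelerator` directive in `nextflow.config`:

    // nextflow.config (illustrative)
    executor {
        gpus = 4            // total NVIDIA GPUs the local executor may assign
    }

    process {
        accelerator = 2     // each task requests 2 of the 4 GPU slots
    }

With these values the local executor would run at most two such tasks concurrently (subject to the usual cpu and memory limits), each seeing a distinct pair of devices through `CUDA_VISIBLE_DEVICES`. Note that `executor.gpus` defaults to `0` (see `configGpus` below), so a task that requests an `accelerator` without this setting would fail with a `ProcessUnrecoverableException`.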

`executor.jobName`
: Determines the name of jobs submitted to the underlying cluster executor e.g. `executor.jobName = { "$task.name - $task.hash" }`. Make sure the resulting job name matches the validation constraints of the underlying batch scheduler.
: This setting is supported by the following executors: Bridge, Condor, Flux, HyperQueue, Lsf, Moab, Nqsii, Oar, PBS, PBS Pro, SGE, SLURM and Google Batch.
@@ -80,6 +80,8 @@ class LocalTaskHandler extends TaskHandler implements FusionAwareTask {

private volatile TaskResult result

List<Integer> gpuSlots

LocalTaskHandler(TaskRun task, LocalExecutor executor) {
super(task)
// create the task handler
@@ -142,11 +144,13 @@ class LocalTaskHandler extends TaskHandler implements FusionAwareTask {
final workDir = task.workDir.toFile()
final logFile = new File(workDir, TaskRun.CMD_LOG)

return new ProcessBuilder()
final pb = new ProcessBuilder()
.redirectErrorStream(true)
.redirectOutput(logFile)
.directory(workDir)
.command(cmd)
applyGpuSlots(pb)
return pb
}

protected ProcessBuilder fusionProcessBuilder() {
@@ -162,10 +166,18 @@ class LocalTaskHandler extends TaskHandler implements FusionAwareTask {

final logPath = Files.createTempFile('nf-task','.log')

return new ProcessBuilder()
final pb = new ProcessBuilder()
.redirectErrorStream(true)
.redirectOutput(logPath.toFile())
.command(List.of('sh','-c', cmd))
applyGpuSlots(pb)
return pb
}

protected void applyGpuSlots(ProcessBuilder pb) {
if( !gpuSlots )
return
pb.environment().put('CUDA_VISIBLE_DEVICES', gpuSlots.join(','))
}
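A quick sketch of the effect (hypothetical slot values, not part of the diff): a handler that was assigned slots 0 and 2 launches its process with `CUDA_VISIBLE_DEVICES=0,2`, so the task only sees those two devices:

    // illustrative only: mirrors what applyGpuSlots() does when gpuSlots == [0, 2]
    def pb = new ProcessBuilder('sh', '-c', 'nvidia-smi')              // hypothetical command
    pb.environment().put('CUDA_VISIBLE_DEVICES', [0, 2].join(','))     // -> "0,2"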

protected ProcessBuilder createLaunchProcessBuilder() {
@@ -23,8 +23,10 @@ import groovy.transform.PackageScope
import groovy.util.logging.Slf4j
import nextflow.Session
import nextflow.exception.ProcessUnrecoverableException
import nextflow.executor.local.LocalTaskHandler
import nextflow.util.Duration
import nextflow.util.MemoryUnit
import nextflow.util.TrackingSemaphore

/**
* Task polling monitor specialized for local execution. It manages tasks scheduling
@@ -58,6 +60,16 @@ class LocalPollingMonitor extends TaskPollingMonitor {
*/
private final long maxMemory

/**
* Semaphore tracking the `free` GPU slots available to execute pending tasks
*/
private TrackingSemaphore availGpus

/**
* Total number of CPUs available in the system
*/
private final int maxGpus

/**
* Create the task polling monitor with the provided named parameters object.
* <p>
@@ -74,6 +86,8 @@ class LocalPollingMonitor extends TaskPollingMonitor {
super(params)
this.availCpus = maxCpus = params.cpus as int
this.availMemory = maxMemory = params.memory as long
this.maxGpus = params.gpus as int
this.availGpus = new TrackingSemaphore(maxGpus)
assert availCpus>0, "Local avail `cpus` attribute cannot be zero"
assert availMemory>0, "Local avail `memory` attribute cannot be zero"
}
@@ -98,14 +112,16 @@

final int cpus = configCpus(session,name)
final long memory = configMem(session,name)
final int gpus = configGpus(session,name)
final int size = session.getQueueSize(name, OS.getAvailableProcessors())

log.debug "Creating local task monitor for executor '$name' > cpus=$cpus; memory=${new MemoryUnit(memory)}; capacity=$size; pollInterval=$pollInterval; dumpInterval=$dumpInterval"
log.debug "Creating local task monitor for executor '$name' > cpus=$cpus; memory=${new MemoryUnit(memory)}; gpus=$gpus; capacity=$size; pollInterval=$pollInterval; dumpInterval=$dumpInterval"

new LocalPollingMonitor(
name: name,
cpus: cpus,
memory: memory,
gpus: gpus,
session: session,
capacity: size,
pollInterval: pollInterval,
@@ -128,6 +144,11 @@ class LocalPollingMonitor extends TaskPollingMonitor {
(session.getExecConfigProp(name, 'memory', OS.getTotalPhysicalMemorySize()) as MemoryUnit).toBytes()
}

@PackageScope
static int configGpus(Session session, String name) {
return session.getExecConfigProp(name, 'gpus', 0) as int
}

/**
* @param handler
* A {@link TaskHandler} instance
@@ -149,6 +170,16 @@
handler.task.getConfig()?.getMemory()?.toBytes() ?: 1L
}

/**
* @param handler
* A {@link TaskHandler} instance
* @return
* The number of gpus requested to execute the specified task
*/
private static int gpus(TaskHandler handler) {
handler.task.getConfig()?.getAccelerator()?.getRequest() ?: 0
}

/**
* Determines if a task can be submitted for execution checking if the resources required
* (cpus and memory) match the amount of avail resource
@@ -174,9 +205,14 @@
if( taskMemory>maxMemory)
throw new ProcessUnrecoverableException("Process requirement exceeds available memory -- req: ${new MemoryUnit(taskMemory)}; avail: ${new MemoryUnit(maxMemory)}")

final result = super.canSubmit(handler) && taskCpus <= availCpus && taskMemory <= availMemory
final taskGpus = gpus(handler)
if( taskGpus>maxGpus )
throw new ProcessUnrecoverableException("Process requirement exceeds available GPUs -- req: $taskGpus; avail: $maxGpus")

final availGpus0 = availGpus.availablePermits()
final result = super.canSubmit(handler) && taskCpus <= availCpus && taskMemory <= availMemory && taskGpus <= availGpus0
if( !result && log.isTraceEnabled( ) ) {
log.trace "Task `${handler.task.name}` cannot be scheduled -- taskCpus: $taskCpus <= availCpus: $availCpus && taskMemory: ${new MemoryUnit(taskMemory)} <= availMemory: ${new MemoryUnit(availMemory)}"
log.trace "Task `${handler.task.name}` cannot be scheduled -- taskCpus: $taskCpus <= availCpus: $availCpus && taskMemory: ${new MemoryUnit(taskMemory)} <= availMemory: ${new MemoryUnit(availMemory)} && taskGpus: $taskGpus <= availGpus: ${availGpus0}"
}
return result
}
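In other words, the GPU check mirrors the existing cpu and memory checks: a request larger than `executor.gpus` is a hard error, while a request larger than what is currently free just leaves the task pending until slots are released. A worked sketch with hypothetical numbers (4 GPUs configured, 3 currently assigned):

    int maxGpus    = 4          // executor.gpus
    int availGpus0 = 1          // availGpus.availablePermits() at this moment
    assert 5 > maxGpus          // accelerator 5 -> ProcessUnrecoverableException
    assert !(2 <= availGpus0)   // accelerator 2 -> canSubmit() returns false, the task waits
    assert 1 <= availGpus0      // accelerator 1 -> the task is eligible for submission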
@@ -192,6 +228,7 @@ class LocalPollingMonitor extends TaskPollingMonitor {
super.submit(handler)
availCpus -= cpus(handler)
availMemory -= mem(handler)
((LocalTaskHandler) handler).gpuSlots = availGpus.acquire(gpus(handler))
}

/**
@@ -204,11 +241,13 @@
* {@code true} when the task is successfully removed from polling queue,
* {@code false} otherwise
*/
@Override
protected boolean remove(TaskHandler handler) {
final result = super.remove(handler)
if( result ) {
availCpus += cpus(handler)
availMemory += mem(handler)
availGpus.release(((LocalTaskHandler) handler).gpuSlots ?: Collections.<Integer>emptyList())
}
return result
}
@@ -0,0 +1,65 @@
/*
* Copyright 2013-2024, Seqera Labs
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package nextflow.util

import java.util.concurrent.Semaphore

import groovy.transform.CompileStatic

/**
* Specialized semaphore that keeps track of which slots
* are being used.
*
* @author Ben Sherman <bentshermann@gmail.com>
*/
@CompileStatic
class TrackingSemaphore {
private final Semaphore semaphore
private final Map<Integer,Boolean> availIds

TrackingSemaphore(int permits) {
semaphore = new Semaphore(permits)
availIds = new HashMap<>(permits)
for( int i=0; i<permits; i++ )
availIds.put(i, true)
}

int availablePermits() {
return semaphore.availablePermits()
}

List<Integer> acquire(int permits) {
semaphore.acquire(permits)
final result = new ArrayList<Integer>(permits)
for( final entry : availIds.entrySet() ) {
if( entry.getValue() ) {
entry.setValue(false)
result.add(entry.getKey())
}
if( result.size() == permits )
break
}
return result
}

void release(List<Integer> ids) {
semaphore.release(ids.size())
for( id in ids )
availIds.put(id, true)
}

}
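A minimal usage sketch (illustrative, not part of the diff): `acquire` returns the concrete slot indexes that were reserved, and `release` puts those same indexes back into the pool, which is what allows the monitor to later export them as `CUDA_VISIBLE_DEVICES`:

    // illustrative: a pool of 4 GPU slots
    def pool  = new TrackingSemaphore(4)
    def first = pool.acquire(2)            // e.g. [0, 1]
    def more  = pool.acquire(1)            // e.g. [2]
    assert pool.availablePermits() == 1
    pool.release(first)                    // slots 0 and 1 become available again
    assert pool.availablePermits() == 3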
@@ -21,6 +21,7 @@ import java.lang.management.ManagementFactory
import com.sun.management.OperatingSystemMXBean
import nextflow.Session
import nextflow.exception.ProcessUnrecoverableException
import nextflow.executor.local.LocalTaskHandler
import nextflow.util.MemoryUnit
import spock.lang.Specification
/**
@@ -38,14 +39,15 @@ class LocalPollingMonitorTest extends Specification {
cpus: 10,
capacity: 20,
memory: _20_GB,
gpus: 0,
session: session,
name: 'local',
pollInterval: 100
)

def task = new TaskRun()
task.config = new TaskConfig(cpus: 3, memory: MemoryUnit.of('2GB'))
def handler = Mock(TaskHandler)
def handler = Mock(LocalTaskHandler)
handler.getTask() >> { task }

expect:
@@ -86,14 +88,15 @@ class LocalPollingMonitorTest extends Specification {
cpus: 10,
capacity: 10,
memory: _20_GB,
gpus: 0,
session: session,
name: 'local',
pollInterval: 100
)

def task = new TaskRun()
task.config = new TaskConfig(cpus: 4, memory: MemoryUnit.of('8GB'))
def handler = Mock(TaskHandler)
def handler = Mock(LocalTaskHandler)
handler.getTask() >> { task }
handler.canForkProcess() >> true
handler.isReady() >> true
@@ -132,14 +135,15 @@ class LocalPollingMonitorTest extends Specification {
cpus: 1,
capacity: 1,
memory: _20_GB,
gpus: 0,
session: session,
name: 'local',
pollInterval: 100
)

def task = new TaskRun()
task.config = new TaskConfig(cpus: 1, memory: MemoryUnit.of('8GB'))
def handler = Mock(TaskHandler)
def handler = Mock(LocalTaskHandler)
handler.getTask() >> { task }
handler.canForkProcess() >> true
handler.isReady() >> true
@@ -167,6 +171,7 @@ class LocalPollingMonitorTest extends Specification {
cpus: 10,
capacity: 20,
memory: _20_GB,
gpus: 0,
session: session,
name: 'local',
pollInterval: 100
@@ -195,6 +200,7 @@ class LocalPollingMonitorTest extends Specification {
cpus: 10,
capacity: 20,
memory: _20_GB,
gpus: 0,
session: session,
name: 'local',
pollInterval: 100