Skip to content

Commit dfc50d4

Browse files
add GPU functionality to dockerCompose
1 parent 8c773f1 commit dfc50d4

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

src/spec-node/dockerCompose.ts

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ async function startContainer(params: DockerResolverParameters, buildParams: Doc
386386
// Save override docker-compose file to disk.
387387
// Persisted folder is a path that will be maintained between sessions
388388
// Note: As a fallback, persistedFolder is set to the build's tmpDir() directory
389-
const overrideFilePath = await writeFeaturesComposeOverrideFile(updatedImageName, currentImageName, mergedConfig, config, versionPrefix, imageDetails, service, idLabels, params.additionalMounts, persistedFolder, featuresStartOverrideFilePrefix, buildCLIHost, output);
389+
const overrideFilePath = await writeFeaturesComposeOverrideFile(updatedImageName, currentImageName, mergedConfig, config, versionPrefix, imageDetails, service, idLabels, params.additionalMounts, persistedFolder, featuresStartOverrideFilePrefix, buildCLIHost, params, output);
390390
if (overrideFilePath) {
391391
// Add file path to override file as parameter
392392
composeGlobalArgs.push('-f', overrideFilePath);
@@ -449,9 +449,10 @@ async function writeFeaturesComposeOverrideFile(
449449
overrideFilePath: string,
450450
overrideFilePrefix: string,
451451
buildCLIHost: CLIHost,
452+
params: DockerResolverParameters,
452453
output: Log,
453454
) {
454-
const composeOverrideContent = await generateFeaturesComposeOverrideContent(updatedImageName, originalImageName, mergedConfig, config, versionPrefix, imageDetails, service, additionalLabels, additionalMounts);
455+
const composeOverrideContent = await generateFeaturesComposeOverrideContent(updatedImageName, originalImageName, mergedConfig, config, versionPrefix, imageDetails, service, additionalLabels, additionalMounts, params);
455456
const overrideFileHasContents = !!composeOverrideContent && composeOverrideContent.length > 0 && composeOverrideContent.trim() !== '';
456457
if (overrideFileHasContents) {
457458
output.write(`Docker Compose override file for creating container:\n${composeOverrideContent}`);
@@ -470,6 +471,12 @@ async function writeFeaturesComposeOverrideFile(
470471
}
471472
}
472473

474+
async function checkDockerSupportForGPU(params: DockerCLIParameters | DockerResolverParameters): Promise<Boolean> {
475+
const result = await dockerCLI(params, 'info', '-f', '{{.Runtimes.nvidia}}');
476+
const runtimeFound = result.stdout.includes('nvidia-container-runtime');
477+
return runtimeFound;
478+
}
479+
473480
async function generateFeaturesComposeOverrideContent(
474481
updatedImageName: string,
475482
originalImageName: string,
@@ -480,6 +487,7 @@ async function generateFeaturesComposeOverrideContent(
480487
service: any,
481488
additionalLabels: string[],
482489
additionalMounts: Mount[],
490+
params: DockerResolverParameters,
483491
) {
484492
const overrideImage = updatedImageName !== originalImageName;
485493

@@ -501,6 +509,19 @@ async function generateFeaturesComposeOverrideContent(
501509
const userCommand = overrideCommand ? [] : composeCommand /* $ already escaped. */
502510
|| (composeEntrypoint ? [/* Ignore image CMD per docker-compose.yml spec. */] : ((await imageDetails()).Config.Cmd || []).map(c => c.replace(/\$/g, '$$$$'))); // $ > $$ to escape docker-compose.yml's interpolation.
503511

512+
const hasGpuRequirement = config.hostRequirements?.gpu;
513+
const addGpuCapability = hasGpuRequirement && await checkDockerSupportForGPU(params);
514+
if (hasGpuRequirement && hasGpuRequirement !== 'optional' && !addGpuCapability) {
515+
throw Error('Unable to detect GPU yet a GPU was required - consider marking it as "optional"');
516+
}
517+
const gpuResources = addGpuCapability ? '' : `
518+
deploy:
519+
resources:
520+
reservations:
521+
devices:
522+
- capabilities: [gpu]
523+
`;
524+
504525
return `${versionPrefix}services:
505526
'${config.service}':${overrideImage ? `
506527
image: ${updatedImageName}` : ''}
@@ -524,7 +545,7 @@ while sleep 1 & wait $$!; do :; done", "-"${userEntrypoint.map(a => `, ${JSON.st
524545
volumes:${mounts.map(m => `
525546
- ${m.source}:${m.target}`).join('')}` : ''}${volumeMounts.length ? `
526547
volumes:${volumeMounts.map(m => `
527-
${m.source}:${m.external ? '\n external: true' : ''}`).join('')}` : ''}
548+
${m.source}:${m.external ? '\n external: true' : ''}`).join('')}` : ''}${gpuResources}
528549
`;
529550
}
530551

0 commit comments

Comments
 (0)