diff --git a/deepspeed/launcher/constants.py b/deepspeed/launcher/constants.py
index 09dd1bdd1def..3d49a7565ea6 100644
--- a/deepspeed/launcher/constants.py
+++ b/deepspeed/launcher/constants.py
@@ -4,6 +4,7 @@
 PDSH_MAX_FAN_OUT = 1024
 OPENMPI_LAUNCHER = 'openmpi'
+JSRUN_LAUNCHER = 'jsrun'
 MPICH_LAUNCHER = 'mpich'
 SLURM_LAUNCHER = 'slurm'
 MVAPICH_LAUNCHER = 'mvapich'
diff --git a/deepspeed/launcher/multinode_runner.py b/deepspeed/launcher/multinode_runner.py
index 6027d1076e80..ff2fdc540967 100644
--- a/deepspeed/launcher/multinode_runner.py
+++ b/deepspeed/launcher/multinode_runner.py
@@ -169,6 +169,68 @@ def get_cmd(self, environment, active_resources):
         return mpirun_cmd + export_cmd + python_exec + [self.user_script
                                                         ] + self.user_arguments
 
 
+class JSRunner(MultiNodeRunner):
+    """Launch multi-node jobs with IBM's ``jsrun`` (LSF / Spectrum MPI).
+
+    The resource-set shape is currently hard coded for ORNL Summit nodes
+    (6 GPUs per node): one task per resource set, with 1 GPU and 7 cores
+    each.
+    """
+    def __init__(self, args, world_info_base64, resource_pool):
+        super().__init__(args, world_info_base64)
+        self.resource_pool = resource_pool
+        # Hard coded for Summit: expose all 6 GPUs of a node.
+        self.add_export('CUDA_VISIBLE_DEVICES', '0,1,2,3,4,5')
+
+    def backend_exists(self):
+        # TODO: if IB is available we should suggest mvapich
+        # This ompi check will still work for jsrun since spectrum-mpi is based on ompi
+        return shutil.which('ompi_info')
+
+    @property
+    def name(self):
+        return "jsrun"
+
+    def validate_args(self):
+        super().validate_args()
+        # TODO: Allow for include/exclude at node-level but not gpu-level
+        if self.args.include != "" or self.args.exclude != "":
+            raise ValueError(
+                f"{self.name} backend does not support worker include/exclusion")
+        if self.args.num_nodes != -1 or self.args.num_gpus != -1:
+            raise ValueError(
+                f"{self.name} backend does not support limiting num nodes/gpus")
+
+    def get_cmd(self, environment, active_resources):
+        total_process_count = sum(self.resource_pool.values())
+
+        # One resource set per rank: 7 cores (-c), 1 GPU (-g), 1 task (-a).
+        jsrun_cmd = [
+            'jsrun',
+            '-n',
+            f'{total_process_count}',
+            '-c',
+            '7',
+            '-g',
+            '1',
+            '-a',
+            '1',
+        ] + split(self.args.launcher_args)
+
+        export_cmd = []
+        for k, v in self.exports.items():
+            export_cmd += ['-E', "{}={}".format(k, v)]
+
+        python_exec = []
+        if not self.args.no_python:
+            python_exec = [sys.executable, "-u"]
+            if self.args.module:
+                python_exec.append("-m")
+
+        return jsrun_cmd + export_cmd + python_exec + [self.user_script
+                                                       ] + self.user_arguments
+
+
 class MPICHRunner(MultiNodeRunner):
     def __init__(self, args, world_info_base64, resource_pool):
diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py
index 044f6ff03365..ed076caae7e3 100755
--- a/deepspeed/launcher/runner.py
+++ b/deepspeed/launcher/runner.py
@@ -18,8 +18,8 @@
 import signal
 import time
 
-from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MPICHRunner
-from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MPICH_LAUNCHER
+from .multinode_runner import PDSHRunner, OpenMPIRunner, MVAPICHRunner, SlurmRunner, MPICHRunner, JSRunner
+from .constants import PDSH_LAUNCHER, OPENMPI_LAUNCHER, MVAPICH_LAUNCHER, SLURM_LAUNCHER, MPICH_LAUNCHER, JSRUN_LAUNCHER
 from ..constants import TORCH_DISTRIBUTED_DEFAULT_PORT
 from ..nebula.constants import NEBULA_EXPORT_ENVS
 from ..utils import logger
@@ -511,6 +511,8 @@ def main(args=None):
         runner = PDSHRunner(args, world_info_base64)
     elif args.launcher == OPENMPI_LAUNCHER:
         runner = OpenMPIRunner(args, world_info_base64, resource_pool)
+    elif args.launcher == JSRUN_LAUNCHER:
+        runner = JSRunner(args, world_info_base64, resource_pool)
     elif args.launcher == MPICH_LAUNCHER:
         runner = MPICHRunner(args, world_info_base64, resource_pool)
     elif args.launcher == MVAPICH_LAUNCHER: