From 8c56c25d84764dccf5979cb40a318bc3ccfac31a Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Mon, 19 Dec 2022 14:41:53 -0800 Subject: [PATCH] [launcher] parse hostfile via regex and added error checks (#2626) --- deepspeed/launcher/runner.py | 59 +++++++++++++++++++---------- tests/unit/launcher/test_run.py | 67 +++++++++++++++++++++++++++++++++ 2 files changed, 107 insertions(+), 19 deletions(-) diff --git a/deepspeed/launcher/runner.py b/deepspeed/launcher/runner.py index 9d015abb2c14..1201dfc3cc09 100755 --- a/deepspeed/launcher/runner.py +++ b/deepspeed/launcher/runner.py @@ -7,6 +7,7 @@ """ import os +import re import sys import json import base64 @@ -182,25 +183,45 @@ def fetch_hostfile(hostfile_path): # e.g., worker-0 slots=16 with open(hostfile_path, 'r') as fd: - resource_pool = collections.OrderedDict() - for line in fd.readlines(): - line = line.strip() - if line == '': - # skip empty lines - continue - try: - hostname, slots = line.split() - _, slot_count = slots.split("=") - slot_count = int(slot_count) - except ValueError as err: - logger.error("Hostfile is not formatted correctly, unable to " - "proceed with training.") - raise err - if hostname in resource_pool: - logger.error("Hostfile contains duplicate hosts, unable to " - "proceed with training.") - raise ValueError(f"host {hostname} is already defined") - resource_pool[hostname] = slot_count + hostfile_text = fd.readlines() + + return _parse_hostfile(hostfile_text) + + +def _parse_hostfile(hostfile_lines): + # Regex matches one or more non-whitespace characters (\S+) at the start of + # the line, followed by one or more whitespace characters (\s+), followed + # by the string "slots=", followed by one or more digits (\d+). + pattern = r'^(\S+)\s+slots=(\d+)' + + resource_pool = collections.OrderedDict() + + for line in hostfile_lines: + line = line.strip() + match = re.search(pattern, line) + if line.startswith("#") or line == "": + # hostfile comment or empty line, ignore + continue + elif match: + host = match.group(1) + num_slots = int(match.group(2)) + if host in resource_pool: + logger.error(f"Bad hostfile text: {hostfile_lines}") + raise ValueError( + f"Hostfile contains multiple entries for {host}, unable to proceed with launching" + ) + resource_pool[host] = num_slots + else: + logger.error(f"Bad hostfile text: {hostfile_lines}") + raise ValueError( + "Hostfile contains a bad entry: {line}, unable to proceed with launching" + ) + + if len(resource_pool) == 0: + logger.error(f"Bad hostfile text: {hostfile_lines}") + raise ValueError( + "Hostfile is empty or not formatted correctly, unable to proceed with launching." + ) return resource_pool diff --git a/tests/unit/launcher/test_run.py b/tests/unit/launcher/test_run.py index f2b0a8b2018a..4677e1e3025a 100644 --- a/tests/unit/launcher/test_run.py +++ b/tests/unit/launcher/test_run.py @@ -106,3 +106,70 @@ def test_num_plus_parser(): dsrun.main(args="--num_nodes 1 --num_gpus 1 -e localhost foo.py".split()) with pytest.raises(ValueError): dsrun.main(args="--num_gpus 1 -e localhost foo.py".split()) + + +def test_hostfile_good(): + # good hostfile w. empty lines and comment + hostfile = """ + worker-1 slots=2 + worker-2 slots=2 + + localhost slots=1 + 123.23.12.10 slots=2 + + #worker-1 slots=3 + # this is a comment + + """ + r = dsrun._parse_hostfile(hostfile.splitlines()) + assert "worker-1" in r + assert "worker-2" in r + assert "localhost" in r + assert "123.23.12.10" in r + assert r["worker-1"] == 2 + assert r["worker-2"] == 2 + assert r["localhost"] == 1 + assert r["123.23.12.10"] == 2 + assert len(r) == 4 + + +def test_hostfiles_bad(): + # duplicate host + hostfile = """ + worker-1 slots=2 + worker-2 slots=1 + worker-1 slots=1 + """ + with pytest.raises(ValueError): + dsrun._parse_hostfile(hostfile.splitlines()) + + # incorrect whitespace + hostfile = """ + this is bad slots=1 + """ + with pytest.raises(ValueError): + dsrun._parse_hostfile(hostfile.splitlines()) + + # no whitespace + hostfile = """ + missingslots + """ + with pytest.raises(ValueError): + dsrun._parse_hostfile(hostfile.splitlines()) + + # empty + hostfile = """ + """ + with pytest.raises(ValueError): + dsrun._parse_hostfile(hostfile.splitlines()) + + # mix of good/bad + hostfile = """ + worker-1 slots=2 + this is bad slots=1 + worker-2 slots=4 + missingslots + + """ + with pytest.raises(ValueError): + dsrun._parse_hostfile(hostfile.splitlines())