Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[launcher] parse hostfile via regex and added error checks #2626

Merged
merged 3 commits into from
Dec 19, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 40 additions & 19 deletions deepspeed/launcher/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""

import os
import re
import sys
import json
import base64
Expand Down Expand Up @@ -182,25 +183,45 @@ def fetch_hostfile(hostfile_path):

# e.g., worker-0 slots=16
with open(hostfile_path, 'r') as fd:
resource_pool = collections.OrderedDict()
for line in fd.readlines():
line = line.strip()
if line == '':
# skip empty lines
continue
try:
hostname, slots = line.split()
_, slot_count = slots.split("=")
slot_count = int(slot_count)
except ValueError as err:
logger.error("Hostfile is not formatted correctly, unable to "
"proceed with training.")
raise err
if hostname in resource_pool:
logger.error("Hostfile contains duplicate hosts, unable to "
"proceed with training.")
raise ValueError(f"host {hostname} is already defined")
resource_pool[hostname] = slot_count
hostfile_text = fd.readlines()

return _parse_hostfile(hostfile_text)


def _parse_hostfile(hostfile_lines):
# Regex matches one or more non-whitespace characters (\S+) at the start of
# the line, followed by one or more whitespace characters (\s+), followed
# by the string "slots=", followed by one or more digits (\d+).
pattern = r'^(\S+)\s+slots=(\d+)'

resource_pool = collections.OrderedDict()

for line in hostfile_lines:
line = line.strip()
match = re.search(pattern, line)
if line.startswith("#") or line == "":
# hostfile comment or empty line, ignore
continue
elif match:
host = match.group(1)
num_slots = int(match.group(2))
if host in resource_pool:
logger.error(f"Bad hostfile text: {hostfile_lines}")
raise ValueError(
f"Hostfile contains multiple entries for {host}, unable to proceed with launching"
)
resource_pool[host] = num_slots
else:
logger.error(f"Bad hostfile text: {hostfile_lines}")
raise ValueError(
"Hostfile contains a bad entry: {line}, unable to proceed with launching"
)

if len(resource_pool) == 0:
logger.error(f"Bad hostfile text: {hostfile_lines}")
raise ValueError(
"Hostfile is empty or not formatted correctly, unable to proceed with launching."
)

return resource_pool

Expand Down
67 changes: 67 additions & 0 deletions tests/unit/launcher/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,70 @@ def test_num_plus_parser():
dsrun.main(args="--num_nodes 1 --num_gpus 1 -e localhost foo.py".split())
with pytest.raises(ValueError):
dsrun.main(args="--num_gpus 1 -e localhost foo.py".split())


def test_hostfile_good():
    """A well-formed hostfile (with blank lines and comments) parses cleanly."""
    # good hostfile w. empty lines and comment
    hostfile = """
worker-1 slots=2
worker-2 slots=2

localhost slots=1
123.23.12.10 slots=2

#worker-1 slots=3
# this is a comment

"""
    # Expected slot count per host; the commented-out worker-1 entry must not
    # override the real one.
    expected_slots = {
        "worker-1": 2,
        "worker-2": 2,
        "localhost": 1,
        "123.23.12.10": 2,
    }

    parsed = dsrun._parse_hostfile(hostfile.splitlines())

    for host in expected_slots:
        assert host in parsed
    for host, slots in expected_slots.items():
        assert parsed[host] == slots
    assert len(parsed) == 4


def test_hostfiles_bad():
    """Each malformed hostfile below must be rejected with ``ValueError``."""
    # duplicate host
    duplicate_host = """
worker-1 slots=2
worker-2 slots=1
worker-1 slots=1
"""

    # incorrect whitespace
    incorrect_whitespace = """
this is bad slots=1
"""

    # no whitespace
    no_whitespace = """
missingslots
"""

    # empty
    empty = """
"""

    # mix of good/bad
    mixed_good_and_bad = """
worker-1 slots=2
this is bad slots=1
worker-2 slots=4
missingslots

"""

    bad_hostfiles = (
        duplicate_host,
        incorrect_whitespace,
        no_whitespace,
        empty,
        mixed_good_and_bad,
    )
    for bad_text in bad_hostfiles:
        with pytest.raises(ValueError):
            dsrun._parse_hostfile(bad_text.splitlines())