Skip to content

Commit

Permalink
[dask] raise more informative error for duplicates in 'machines' (fixes
Browse files Browse the repository at this point in the history
#4057) (#4059)

* [dask] raise more informative error for duplicates in 'machines'

* uncomment

* avoid test failure

* Revert "avoid test failure"

This reverts commit 9442bdf.
  • Loading branch information
jameslamb authored Mar 10, 2021
1 parent b75a43a commit 296397d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ def _machines_to_worker_map(machines: str, worker_addresses: List[str]) -> Dict[
Dictionary where keys are work addresses in the form expected by Dask and values are a port for LightGBM to use.
"""
machine_addresses = machines.split(",")

if len(set(machine_addresses)) != len(machine_addresses):
raise ValueError(f"Found duplicates in 'machines' ({machines}). Each entry in 'machines' must be a unique IP-port combination.")

machine_to_port = defaultdict(set)
for address in machine_addresses:
host, port = address.split(":")
Expand Down
12 changes: 12 additions & 0 deletions tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,6 +1116,7 @@ def test_machines_should_be_used_if_provided(task, output):
client.rebalance()

n_workers = len(client.scheduler_info()['workers'])
assert n_workers > 1
open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)]
dask_model = dask_model_factory(
n_estimators=5,
Expand All @@ -1134,6 +1135,17 @@ def test_machines_should_be_used_if_provided(task, output):
s.bind(('127.0.0.1', open_ports[0]))
dask_model.fit(dX, dy, group=dg)

# an informative error should be raised if "machines" has duplicates
one_open_port = lgb.dask._find_random_open_port()
dask_model.set_params(
machines=",".join([
"127.0.0.1:" + str(one_open_port)
for _ in range(n_workers)
])
)
with pytest.raises(ValueError, match="Found duplicates in 'machines'"):
dask_model.fit(dX, dy, group=dg)


@pytest.mark.parametrize(
"classes",
Expand Down

0 comments on commit 296397d

Please sign in to comment.