Skip to content
This repository was archived by the owner on Jul 16, 2021. It is now read-only.

Use Rabit tracker get_host_ip('auto') to pick best tracker IP address #40

Merged
merged 1 commit into from
Jun 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions dask_xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from toolz import assoc, first
from tornado import gen

from .tracker import RabitTracker
from .tracker import RabitTracker, get_host_ip

try:
import sparse
Expand All @@ -35,6 +35,8 @@ def parse_host_port(address):

def start_tracker(host, n_workers):
""" Start Rabit tracker """
if host is None:
host = get_host_ip('auto')
env = {"DMLC_NUM_WORKER": n_workers}
rabit = RabitTracker(hostIP=host, nslave=n_workers)
env.update(rabit.slave_envs())
Expand Down Expand Up @@ -116,6 +118,7 @@ def train_part(
result = None
evals_result = None
finally:
logger.info("Finalizing Rabit, Rank %d", xgb.rabit.get_rank())
xgb.rabit.finalize()
return result, evals_result

Expand Down Expand Up @@ -205,10 +208,9 @@ def _train(
ncores = yield client.scheduler.ncores() # Number of cores per worker

# Start the XGBoost tracker on the Dask scheduler
host, port = parse_host_port(client.scheduler.address)
env = yield client._run_on_scheduler(
start_tracker, host.strip("/:"), len(worker_map)
)
env = yield client._run_on_scheduler(start_tracker,
None,
len(worker_map))

# Tell each worker to train on the chunks/parts that it has locally
futures = [
Expand Down
2 changes: 2 additions & 0 deletions dask_xgboost/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,13 +169,15 @@ def __init__(self, hostIP, nslave, port=9091, port_end=9999):
sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
for port in range(port, port_end):
try:
logging.info('Binding Rabit tracker %s:%d', hostIP, port)
sock.bind((hostIP, port))
self.port = port
break
except socket.error as e:
if e.errno in [98, 48]:
continue
else:
logging.error(e, exc_info=True)
raise
sock.listen(256)
self.sock = sock
Expand Down