Skip to content

Commit 36449ea

Browse files
authored
(torch/elastic) add fqdn hostname to error printout (pytorch#66182) (pytorch#66662)
Summary: Pull Request resolved: pytorch#66182 closes pytorch#63174 Does a few things: 1. adds hostname to the error report 2. moves the "root cause" section to the end (presumably since the logs are being "tailed" we want the root cause to appear at the end) 3. moves redundant error info logging to debug 4. makes the border max 60 char in length and justifies left for the header NOTE: YOU HAVE TO annotate your main function with torch.distributed.elastic.multiprocessing.errors.record, otherwise no traceback is printed (this is because python exception propagation does NOT work out of the box for IPC - hence the extra record annotation). Test Plan: Sample ``` ============================================================ run_script_path FAILED ------------------------------------------------------------ Failures: <NO_OTHER_FAILURES> ------------------------------------------------------------ Root Cause (first observed failure): [0]: time : 2021-10-05_17:37:22 host : devvm4955.prn0.facebook.com rank : 0 (local_rank: 0) exitcode : 1 (pid: 3296201) error_file: /home/kiuk/tmp/elastic/none_3_lsytqe/attempt_0/0/error.json traceback : Traceback (most recent call last): File "/tmp/jetter.xr3_x6qq/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 372, in wrapper return f(*args, **kwargs) File "main.py", line 28, in main raise RuntimeError(args.throws) RuntimeError: foobar ============================================================ ``` Reviewed By: cbalioglu, aivanou Differential Revision: D31416492 fbshipit-source-id: 0aeaf6e634e23ce0ea7f6a03b12c8a9ac57246e9
1 parent b544cbd commit 36449ea

File tree

6 files changed

+84
-66
lines changed

6 files changed

+84
-66
lines changed

test/distributed/elastic/multiprocessing/api_test.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import time
1616
import unittest
1717
from itertools import product
18-
from typing import Dict, List, Union, Callable
18+
from typing import Callable, Dict, List, Union
1919
from unittest import mock
2020
from unittest.mock import patch
2121

@@ -24,25 +24,25 @@
2424
from torch.distributed.elastic.multiprocessing import ProcessFailure, start_processes
2525
from torch.distributed.elastic.multiprocessing.api import (
2626
MultiprocessContext,
27-
SignalException,
2827
RunProcsResult,
28+
SignalException,
2929
Std,
3030
_validate_full_rank,
31-
to_map,
3231
_wrap,
32+
to_map,
3333
)
3434
from torch.distributed.elastic.multiprocessing.errors.error_handler import _write_error
3535
from torch.testing._internal.common_utils import (
36+
IS_IN_CI,
37+
IS_MACOS,
38+
IS_WINDOWS,
3639
NO_MULTIPROCESSING_SPAWN,
3740
TEST_WITH_ASAN,
38-
TEST_WITH_TSAN,
3941
TEST_WITH_DEV_DBG_ASAN,
40-
IS_IN_CI,
41-
IS_WINDOWS,
42-
IS_MACOS,
42+
TEST_WITH_TSAN,
43+
run_tests,
4344
sandcastle_skip_if,
4445
)
45-
from torch.testing._internal.common_utils import run_tests
4646

4747

4848
class RunProcResultsTest(unittest.TestCase):
@@ -224,6 +224,7 @@ def start_processes_zombie_test(
224224

225225
# tests incompatible with tsan or asan
226226
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
227+
227228
class StartProcessesTest(unittest.TestCase):
228229
def setUp(self):
229230
self.test_dir = tempfile.mkdtemp(prefix=f"{self.__class__.__name__}_")
@@ -251,12 +252,15 @@ def assert_pids_noexist(self, pids: Dict[int, int]):
251252

252253
def test_to_map(self):
253254
local_world_size = 2
254-
self.assertEqual({0: Std.OUT, 1: Std.OUT}, to_map(Std.OUT, local_world_size))
255+
self.assertEqual(
256+
{0: Std.OUT, 1: Std.OUT}, to_map(Std.OUT, local_world_size)
257+
)
255258
self.assertEqual(
256259
{0: Std.NONE, 1: Std.OUT}, to_map({1: Std.OUT}, local_world_size)
257260
)
258261
self.assertEqual(
259-
{0: Std.ERR, 1: Std.OUT}, to_map({0: Std.ERR, 1: Std.OUT}, local_world_size)
262+
{0: Std.ERR, 1: Std.OUT},
263+
to_map({0: Std.ERR, 1: Std.OUT}, local_world_size),
260264
)
261265

262266
def test_invalid_log_dir(self):
@@ -382,9 +386,7 @@ def test_void_function(self):
382386
results = pc.wait(period=0.1)
383387
self.assertEqual({0: None, 1: None}, results.return_values)
384388

385-
@sandcastle_skip_if(
386-
TEST_WITH_DEV_DBG_ASAN, "tests incompatible with asan"
387-
)
389+
@sandcastle_skip_if(TEST_WITH_DEV_DBG_ASAN, "tests incompatible with asan")
388390
def test_function_large_ret_val(self):
389391
# python multiprocessing.queue module uses pipes and actually PipedQueues
390392
# This means that if a single object is greater than a pipe size
@@ -439,7 +441,9 @@ def test_function_raise(self):
439441
self.assertEqual(1, failure.exitcode)
440442
self.assertEqual("<N/A>", failure.signal_name())
441443
self.assertEqual(pc.pids()[0], failure.pid)
442-
self.assertEqual(os.path.join(log_dir, "0", "error.json"), error_file)
444+
self.assertEqual(
445+
os.path.join(log_dir, "0", "error.json"), error_file
446+
)
443447
self.assertEqual(
444448
int(error_file_data["message"]["extraInfo"]["timestamp"]),
445449
int(failure.timestamp),
@@ -541,17 +545,22 @@ def test_multiprocessing_context_poll_raises_exception(self):
541545
run_result = mp_context._poll()
542546
self.assertEqual(1, len(run_result.failures))
543547
failure = run_result.failures[0]
544-
self.assertEqual("Signal 1 (SIGHUP) received by PID 123", failure.message)
548+
self.assertEqual(
549+
"Signal 1 (SIGHUP) received by PID 123", failure.message
550+
)
545551

546552

547553
# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
548554
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS):
555+
549556
class StartProcessesListTest(StartProcessesTest):
550557
########################################
551558
# start_processes as binary tests
552559
########################################
553560
def test_function(self):
554-
for start_method, redirs in product(self._start_methods, redirects_oss_test()):
561+
for start_method, redirs in product(
562+
self._start_methods, redirects_oss_test()
563+
):
555564
with self.subTest(start_method=start_method, redirs=redirs):
556565
pc = start_processes(
557566
name="echo",
@@ -644,6 +653,7 @@ def test_binary_redirect_and_tee(self):
644653

645654
# tests incompatible with tsan or asan, the redirect functionality does not work on macos or windows
646655
if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS or IS_IN_CI):
656+
647657
class StartProcessesNotCITest(StartProcessesTest):
648658
def test_wrap_bad(self):
649659
none = ""
@@ -796,7 +806,8 @@ def test_function_exit(self):
796806
self.assertEqual(pc.pids()[0], failure.pid)
797807
self.assertEqual("<N/A>", error_file)
798808
self.assertEqual(
799-
f"Process failed with exitcode {FAIL}", failure.message
809+
"To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html",
810+
failure.message,
800811
)
801812
self.assertLessEqual(failure.timestamp, int(time.time()))
802813

test/distributed/elastic/multiprocessing/errors/api_test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,10 @@ def test_process_failure_no_error_file(self):
115115
pf = self.failure_without_error_file(exitcode=138)
116116
self.assertEqual("<N/A>", pf.signal_name())
117117
self.assertEqual("<N/A>", pf.error_file)
118-
self.assertEqual("Process failed with exitcode 138", pf.message)
118+
self.assertEqual(
119+
"To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html",
120+
pf.message,
121+
)
119122

120123
def test_child_failed_error(self):
121124
pf0 = self.failure_with_error_file(exception=SentinelError("rank 0"))
@@ -134,7 +137,7 @@ def test_child_failed_error(self):
134137
rank: 0 (local_rank: 0)
135138
exitcode: 1 (pid: 997)
136139
error_file: /tmp/ApiTesttbb37ier/error.json
137-
msg: "SentinelError: rank 0"
140+
traceback: "SentinelError: rank 0"
138141
=============================================
139142
Other Failures:
140143
[1]:
@@ -148,7 +151,7 @@ def test_child_failed_error(self):
148151
rank: 2 (local_rank: 0)
149152
exitcode: 138 (pid: 997)
150153
error_file: <N/A>
151-
msg: "Process failed with exitcode 138"
154+
traceback: To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
152155
*********************************************
153156
"""
154157
print(ex)

torch/distributed/elastic/agent/server/local_elastic_agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@
2121
WorkerState,
2222
)
2323
from torch.distributed.elastic.metrics.api import prof
24-
from torch.distributed.elastic.multiprocessing import start_processes, PContext
24+
from torch.distributed.elastic.multiprocessing import PContext, start_processes
2525
from torch.distributed.elastic.utils import macros
2626
from torch.distributed.elastic.utils.logging import get_logger
2727

28+
2829
log = get_logger()
2930

3031

torch/distributed/elastic/multiprocessing/errors/__init__.py

Lines changed: 25 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
import json
5252
import os
5353
import signal
54+
import socket
5455
import time
5556
import warnings
5657
from dataclasses import dataclass, field
@@ -109,7 +110,7 @@ def __post_init__(self):
109110
try:
110111
with open(self.error_file, "r") as fp:
111112
self.error_file_data = json.load(fp)
112-
log.info(
113+
log.debug(
113114
f"User process failed with error data: {json.dumps(self.error_file_data, indent=2)}"
114115
)
115116
self.message, self.timestamp = self._get_error_data(
@@ -130,7 +131,7 @@ def __post_init__(self):
130131
f" received by PID {self.pid}"
131132
)
132133
else:
133-
self.message = f"Process failed with exitcode {self.exitcode}"
134+
self.message = "To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html"
134135

135136
def _get_error_data(self, error_file_data: Dict[str, Any]) -> Tuple[str, int]:
136137
message = error_file_data["message"]
@@ -162,24 +163,24 @@ def timestamp_isoformat(self):
162163
GlobalRank = int
163164

164165
_FAILURE_FORMAT_TEMPLATE = """[${idx}]:
165-
time: ${time}
166-
rank: ${rank} (local_rank: ${local_rank})
167-
exitcode: ${exitcode} (pid: ${pid})
166+
time : ${time}
167+
host : ${hostname}
168+
rank : ${rank} (local_rank: ${local_rank})
169+
exitcode : ${exitcode} (pid: ${pid})
168170
error_file: ${error_file}
169-
msg: ${message}"""
171+
traceback : ${message}"""
170172

171173
# extra new lines before and after are intentional
172174
_MSG_FORMAT_TEMPLATE = """
173175
${boarder}
174176
${title}
175177
${section}
176-
Root Cause:
177-
${root_failure}
178-
${section}
179-
Other Failures:
178+
Failures:
180179
${other_failures}
181-
${boarder}
182-
"""
180+
${section}
181+
Root Cause (first observed failure):
182+
${root_failure}
183+
${boarder}"""
183184

184185

185186
class ChildFailedError(Exception):
@@ -230,8 +231,8 @@ def get_first_failure(self) -> Tuple[GlobalRank, ProcessFailure]:
230231
rank = min(self.failures.keys(), key=lambda r: self.failures[r].timestamp)
231232
return rank, self.failures[rank]
232233

233-
def format_msg(self, boarder_delim="*", section_delim="="):
234-
title = f" {self.name} FAILED "
234+
def format_msg(self, boarder_delim="=", section_delim="-"):
235+
title = f"{self.name} FAILED"
235236
root_rank, root_failure = self.get_first_failure()
236237

237238
root_failure_fmt: str = ""
@@ -246,11 +247,11 @@ def format_msg(self, boarder_delim="*", section_delim="="):
246247
other_failures_fmt.append(fmt)
247248

248249
# upper boundary on width
249-
width = min(width, 80)
250+
width = min(width, 60)
250251

251252
return Template(_MSG_FORMAT_TEMPLATE).substitute(
252253
boarder=boarder_delim * width,
253-
title=title.center(width),
254+
title=title,
254255
section=section_delim * width,
255256
root_failure=root_failure_fmt,
256257
other_failures="\n".join(other_failures_fmt or [" <NO_OTHER_FAILURES>"]),
@@ -279,6 +280,7 @@ def _format_failure(
279280
fmt = Template(_FAILURE_FORMAT_TEMPLATE).substitute(
280281
idx=idx,
281282
time=failure.timestamp_isoformat(),
283+
hostname=socket.getfqdn(),
282284
rank=rank,
283285
local_rank=failure.local_rank,
284286
exitcode=failure.exitcode,
@@ -292,32 +294,6 @@ def _format_failure(
292294
return fmt, width
293295

294296

295-
def _no_error_file_warning_msg(rank: int, failure: ProcessFailure) -> str:
296-
msg = [
297-
"CHILD PROCESS FAILED WITH NO ERROR_FILE",
298-
f"Child process {failure.pid} (local_rank {rank}) FAILED (exitcode {failure.exitcode})",
299-
f"Error msg: {failure.message}",
300-
f"Without writing an error file to {failure.error_file}.",
301-
"While this DOES NOT affect the correctness of your application,",
302-
"no trace information about the error will be available for inspection.",
303-
"Consider decorating your top level entrypoint function with",
304-
"torch.distributed.elastic.multiprocessing.errors.record. Example:",
305-
"",
306-
r" from torch.distributed.elastic.multiprocessing.errors import record",
307-
"",
308-
r" @record",
309-
r" def trainer_main(args):",
310-
r" # do train",
311-
]
312-
width = 0
313-
for line in msg:
314-
width = max(width, len(line))
315-
316-
boarder = "*" * width
317-
header = "CHILD PROCESS FAILED WITH NO ERROR_FILE".center(width)
318-
return "\n".join(["\n", boarder, header, boarder, *msg, boarder])
319-
320-
321297
def record(
322298
fn: Callable[..., T], error_handler: Optional[ErrorHandler] = None
323299
) -> Callable[..., T]:
@@ -372,7 +348,13 @@ def wrapper(*args, **kwargs):
372348
if failure.error_file != _NOT_AVAILABLE:
373349
error_handler.dump_error_file(failure.error_file, failure.exitcode)
374350
else:
375-
warnings.warn(_no_error_file_warning_msg(rank, failure))
351+
log.info(
352+
(
353+
f"local_rank {rank} FAILED with no error file."
354+
f" Decorate your entrypoint fn with @record for traceback info."
355+
f" See: https://pytorch.org/docs/stable/elastic/errors.html"
356+
)
357+
)
376358
raise
377359
except Exception as e:
378360
error_handler.record_exception(e)

torch/distributed/elastic/multiprocessing/errors/error_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def dump_error_file(self, rootcause_error_file: str, error_code: int = 0):
107107
else:
108108
rootcause_error["message"]["errorCode"] = error_code
109109

110-
log.info(
110+
log.debug(
111111
f"child error file ({rootcause_error_file}) contents:\n"
112112
f"{json.dumps(rootcause_error, indent=2)}"
113113
)

torch/distributed/run.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,27 @@ def train():
304304
305305
if should_checkpoint:
306306
save_checkpoint(checkpoint_path)
307+
308+
9. (Recommended) On worker errors, this tool will summarize the details of the error
309+
(e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp)
310+
is heuristically reported as the "Root Cause" error. To get tracebacks as part of this
311+
error summary print out, you must decorate your main entrypoint function in your
312+
training script as shown in the example below. If not decorated, then the summary
313+
will not include the traceback of the exception and will only contain the exitcode.
314+
For details on torchelastic error handling see: https://pytorch.org/docs/stable/elastic/errors.html
315+
316+
::
317+
318+
from torch.distributed.elastic.multiprocessing.errors import record
319+
320+
@record
321+
def main():
322+
# do train
323+
pass
324+
325+
if __name__ == "__main__":
326+
main()
327+
307328
"""
308329
import logging
309330
import os
@@ -597,7 +618,7 @@ def config_from_args(args) -> Tuple[LaunchConfig, Union[Callable, str], List[str
597618
if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1:
598619
omp_num_threads = 1
599620
log.warning(
600-
f"*****************************************\n"
621+
f"\n*****************************************\n"
601622
f"Setting OMP_NUM_THREADS environment variable for each process to be "
602623
f"{omp_num_threads} in default, to avoid your system being overloaded, "
603624
f"please further tune the variable for optimal performance in "

0 commit comments

Comments
 (0)