Commit 539286a

AlexDenisov authored and cgestes committed
Inductor annotations (pytorch#130429)
Add NVTX annotations around training phases and buffer computations.

RFC/discussion: https://dev-discuss.pytorch.org/t/rfc-performance-profiling-at-scale-with-details-nvtx-annotations/2224

Screenshot (2024-07-10): https://github.com/pytorch/pytorch/assets/1175576/9ade139c-d393-473f-9b68-6c25da367dc4

Pull Request resolved: pytorch#130429
Approved by: https://github.com/aorenste, https://github.com/eellison, https://github.com/albanD
Co-authored-by: Cedric GESTES <cedric.gestes@flex.ai>
1 parent 24650c3 commit 539286a

File tree

8 files changed: +132 -0 lines changed
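Not part of the commit: a minimal end-to-end sketch of how the new annotations might be used, assuming a CUDA build. It relies only on what the diff introduces (the annotate_training config / TORCHINDUCTOR_ANNOTATE_TRAINING env var and the forward/backward/inference phase labels); the model, optimizer, and profiling command are illustrative.

# Hypothetical usage sketch: enable Inductor's NVTX training-phase annotations
# and capture them with an NVTX-aware profiler such as Nsight Systems.
import torch
import torch._inductor.config as inductor_config

# Equivalent to running with TORCHINDUCTOR_ANNOTATE_TRAINING=1 in the environment.
inductor_config.annotate_training = True

model = torch.nn.Linear(128, 128).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
compiled = torch.compile(model)

x = torch.randn(32, 128, device="cuda")
for _ in range(3):
    opt.zero_grad()
    loss = compiled(x).sum()   # forward graph -> "forward" NVTX range
    loss.backward()            # backward graph -> "backward" NVTX range
    opt.step()

# Running this under `nsys profile python train.py` should show the
# forward/backward ranges emitted by the generated wrapper code.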
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
# Owner(s): ["module: inductor"]
import torch
import torch._inductor.config as inductor_config
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.triton_utils import requires_cuda


class InductorAnnotationTestCase(TestCase):
    def get_code(self):
        def f(a, b):
            return a + b, a * b

        a = torch.randn(5, device="cuda")
        b = torch.randn(5, device="cuda")
        f_comp = torch.compile(f)

        _, code = run_and_get_code(f_comp, a, b)
        return code[0]

    @requires_cuda
    def test_no_annotations(self):
        code = self.get_code()

        self.assertTrue("from torch.cuda import nvtx" not in code)
        self.assertTrue("training_annotation" not in code)

    @inductor_config.patch(annotate_training=True)
    @requires_cuda
    def test_training_annotation(self):
        code = self.get_code()

        self.assertTrue("from torch.cuda import nvtx" in code)
        self.assertEqual(
            code.count("training_annotation = nvtx._device_range_start('inference')"), 1
        )
        self.assertEqual(code.count("nvtx._device_range_end(training_annotation)"), 1)


if __name__ == "__main__":
    run_tests()

torch/_C/_nvtx.pyi

Lines changed: 2 additions & 0 deletions
@@ -5,3 +5,5 @@ def rangePop() -> int: ...
 def rangeStartA(message: str) -> int: ...
 def rangeEnd(int) -> None: ...
 def markA(message: str) -> None: ...
+def deviceRangeStart(message: str, stream: int) -> object: ...
+def deviceRangeEnd(range_handle: object, stream: int) -> None: ...

torch/_inductor/codegen/wrapper.py

Lines changed: 11 additions & 0 deletions
@@ -772,6 +772,8 @@ def write_header(self) -> None:
             )
         except (AttributeError, ImportError):
             pass
+        if config.annotate_training:
+            self.header.writeline("from torch.cuda import nvtx")
 
     def include_extra_header(self, header: str):
         pass
@@ -889,6 +891,11 @@ def {self.launcher_fn_name}(args):
         with self.prefix.indent():
             if config.triton.debug_sync_graph:
                 self.prefix.writeline(V.graph.device_ops.synchronize())
+            phase = V.graph.get_training_phase()
+            if config.annotate_training:
+                self.prefix.writeline(
+                    f"training_annotation = nvtx._device_range_start('{phase}')"
+                )
             if V.graph.graph_inputs:
                 lhs = ", ".join(V.graph.graph_input_names)
                 if len(V.graph.graph_input_names) == 1:
@@ -1175,6 +1182,10 @@ def _generate(self, is_inference):
         if config.triton.autotune_at_compile_time:
             self.generate_and_run_autotune_block()
 
+        if config.annotate_training:
+            self.wrapper_call.writeline(
+                "nvtx._device_range_end(training_annotation)"
+            )
         self.generate_return(output_refs)
 
         self.finalize_prefix()
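For orientation (not part of the diff): with annotate_training enabled, the generated Python wrapper should roughly take the shape below. This is a hand-written sketch based on the writelines above and the strings asserted in the test; real generated code launches Triton/C++ kernels rather than the stand-in eager ops used here.

from torch.cuda import nvtx  # emitted by write_header() when config.annotate_training is set

def call(args):
    arg0_1, arg1_1 = args
    # emitted into the prefix; the label comes from V.graph.get_training_phase()
    training_annotation = nvtx._device_range_start('inference')
    # ... kernel launches / buffer computations (stand-ins below) ...
    buf0 = arg0_1 + arg1_1
    buf1 = arg0_1 * arg1_1
    # emitted just before generate_return()
    nvtx._device_range_end(training_annotation)
    return (buf0, buf1)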

torch/_inductor/config.py

Lines changed: 4 additions & 0 deletions
@@ -773,6 +773,10 @@ def decide_compile_threads() -> int:
 )
 
 
+# Adds NVTX annotations around training phases
+annotate_training: bool = os.environ.get("TORCHINDUCTOR_ANNOTATE_TRAINING", "0") == "1"
+
+
 # config specific to codegen/cpp.py
 class cpp:
     # set to torch.get_num_threads()
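The flag is off by default. Two ways to turn it on (not part of the diff; this mirrors the env var above and the config patch used in the test):

# 1) Environment variable, read when torch._inductor.config is imported:
#      TORCHINDUCTOR_ANNOTATE_TRAINING=1 python train.py
#
# 2) Programmatically, scoped to a block of code (the test uses the same
#    patch helper as a decorator):
import torch
import torch._inductor.config as inductor_config

def fn(a, b):
    return a + b, a * b

with inductor_config.patch(annotate_training=True):
    compiled_fn = torch.compile(fn)
    compiled_fn(torch.randn(5, device="cuda"), torch.randn(5, device="cuda"))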

torch/_inductor/graph.py

Lines changed: 7 additions & 0 deletions
@@ -519,6 +519,13 @@ def set_current_device(self, device: torch.device) -> Iterator[None]:
         finally:
             self.current_device = prior
 
+    def get_training_phase(self) -> str:
+        if self.is_inference:
+            return "inference"
+        if self.is_backward:
+            return "backward"
+        return "forward"
+
     @staticmethod
     def decide_layout_opt(gm: GraphModule, *, is_inference: bool) -> bool:
         """

torch/csrc/cuda/shared/nvtx.cpp

Lines changed: 33 additions & 0 deletions
@@ -6,10 +6,41 @@
 #else
 #include <nvToolsExt.h>
 #endif
+#include <cuda_runtime.h>
 #include <torch/csrc/utils/pybind.h>
 
 namespace torch::cuda::shared {
 
+struct RangeHandle {
+  nvtxRangeId_t id;
+  const char* msg;
+};
+
+static void device_callback_range_end(void* userData) {
+  RangeHandle* handle = ((RangeHandle*)userData);
+  nvtxRangeEnd(handle->id);
+  free((void*)handle->msg);
+  free((void*)handle);
+}
+
+static void device_nvtxRangeEnd(void* handle, std::intptr_t stream) {
+  cudaLaunchHostFunc((cudaStream_t)stream, device_callback_range_end, handle);
+}
+
+static void device_callback_range_start(void* userData) {
+  RangeHandle* handle = ((RangeHandle*)userData);
+  handle->id = nvtxRangeStartA(handle->msg);
+}
+
+static void* device_nvtxRangeStart(const char* msg, std::intptr_t stream) {
+  RangeHandle* handle = (RangeHandle*)calloc(sizeof(RangeHandle), 1);
+  handle->msg = strdup(msg);
+  handle->id = 0;
+  cudaLaunchHostFunc(
+      (cudaStream_t)stream, device_callback_range_start, (void*)handle);
+  return handle;
+}
+
 void initNvtxBindings(PyObject* module) {
   auto m = py::handle(module).cast<py::module>();
 
@@ -23,6 +54,8 @@ void initNvtxBindings(PyObject* module) {
   nvtx.def("rangeStartA", nvtxRangeStartA);
   nvtx.def("rangeEnd", nvtxRangeEnd);
   nvtx.def("markA", nvtxMarkA);
+  nvtx.def("deviceRangeStart", device_nvtxRangeStart);
+  nvtx.def("deviceRangeEnd", device_nvtxRangeEnd);
 }
 
 } // namespace torch::cuda::shared

torch/cuda/nvtx.py

Lines changed: 32 additions & 0 deletions
@@ -66,6 +66,38 @@ def range_end(range_id) -> None:
     _nvtx.rangeEnd(range_id)
 
 
+def _device_range_start(msg: str, stream: int = 0) -> object:
+    """
+    Marks the start of a range with a string message.
+    It returns an opaque heap-allocated handle for this range
+    to pass to the corresponding call to _device_range_end().
+
+    A key difference between this and range_start is that
+    range_start marks the range right away, while _device_range_start
+    marks the start of the range only once all preceding tasks on the
+    CUDA stream have completed.
+
+    Returns: An opaque heap-allocated handle that should be passed to _device_range_end().
+
+    Args:
+        msg (str): ASCII message to associate with the range.
+        stream (int): CUDA stream id.
+    """
+    return _nvtx.deviceRangeStart(msg, stream)
+
+
+def _device_range_end(range_handle: object, stream: int = 0) -> None:
+    """
+    Marks the end of a range for a given range_handle once all preceding tasks
+    on the CUDA stream have completed.
+
+    Args:
+        range_handle: a unique handle for the started range.
+        stream (int): CUDA stream id.
+    """
+    _nvtx.deviceRangeEnd(range_handle, stream)
+
+
 def mark(msg):
     """
     Describe an instantaneous event that occurred at some point.
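A hedged usage sketch of the new private helpers (not part of the diff): the function names and signatures come from the code above; the explicit stream id from torch.cuda.current_stream().cuda_stream, the "matmul_block" label, and the workload are illustrative. These helpers are underscore-prefixed and primarily intended for Inductor's generated wrapper code.

import torch
from torch.cuda import nvtx

x = torch.randn(2048, 2048, device="cuda")
stream_id = torch.cuda.current_stream().cuda_stream  # raw CUDA stream id as an int

# The range does not start immediately: a host callback enqueued on the stream
# starts it once the work queued before _device_range_start has finished.
handle = nvtx._device_range_start("matmul_block", stream_id)
for _ in range(10):
    y = x @ x
# The end is likewise enqueued on the stream, so the range brackets the matmuls
# in stream order rather than in host-call order.
nvtx._device_range_end(handle, stream_id)

torch.cuda.synchronize()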

torch/utils/hipify/cuda_to_hip_mappings.py

Lines changed: 2 additions & 0 deletions
@@ -5035,6 +5035,7 @@
         "cudaLaunchCooperativeKernel",
         ("hipLaunchCooperativeKernel", CONV_EXEC, API_RUNTIME),
     ),
+    ("cudaLaunchHostFunc", ("hipLaunchHostFunc", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED)),
     (
         "cudaSetupArgument",
         ("hipSetupArgument", CONV_EXEC, API_RUNTIME, HIP_UNSUPPORTED),
@@ -7965,6 +7966,7 @@
     ("nvtxRangePop", ("roctxRangePop", CONV_OTHER, API_ROCTX)),
     ("nvtxRangeStartA", ("roctxRangeStartA", CONV_OTHER, API_ROCTX)),
     ("nvtxRangeEnd", ("roctxRangeStop", CONV_OTHER, API_ROCTX)),
+    ("nvtxRangeId_t", ("int", CONV_OTHER, API_ROCTX)),
     ("nvmlReturn_t", ("rsmi_status_t", CONV_OTHER, API_ROCMSMI)),
     ("NVML_SUCCESS", ("RSMI_STATUS_SUCCESS", CONV_OTHER, API_ROCMSMI)),
     ("NVML_P2P_CAPS_INDEX_READ", ("RSMI_STATUS_SUCCESS", CONV_OTHER, API_ROCMSMI)),
