Add start_trace and stop_trace API in profiler #8743


Merged: 4 commits, Feb 27, 2025
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -157,6 +157,7 @@ function run_xla_op_tests1 {
run_test "$CDIR/test_async_closures.py"
run_test "$CDIR/test_hlo_metadata.py"
run_test "$CDIR/test_profiler.py"
run_test "$CDIR/test_profiler_session.py"
run_test "$CDIR/pjrt/test_runtime.py"
run_test "$CDIR/pjrt/test_runtime_single_proc_gpu.py"
run_test "$CDIR/pjrt/test_runtime_multi_gpu.py"
70 changes: 70 additions & 0 deletions test/test_profiler_session.py
@@ -0,0 +1,70 @@
import glob
import os
from absl.testing import absltest

import torch
import torch_xla.debug.profiler as xp


def _run_computation():

  class M(torch.nn.Module):

    def __init__(self):
      super(M, self).__init__()
      self.fc1 = torch.nn.Linear(10, 5)
      self.fc2 = torch.nn.Linear(5, 10)

    def forward(self, x):
      with xp.Trace('fc1'):
        x = self.fc1(x)
      with xp.Trace('fc2'):
        x = self.fc2(x)
      return x

  m = M()
  m = m.to('xla')
  x = torch.randn(10, 10).to('xla')
  for _ in range(20):
    y = m(x)
  y.cpu()


class TestProfilerSession(absltest.TestCase):
Collaborator: do we need a long-duration profile test?

Collaborator (author): No, I don't think we need a long-duration profile test in
torch_xla. The goal of this PR is to provide better usability to users; the
capability of the profiler itself is out of scope of this PR (that belongs in
the underlying TSL library).


  def setUp(self):
    self.server = xp.start_server(8005)

  def test_start_and_stop(self):
    tempdir = self.create_tempdir().full_path
    xp.start_trace(tempdir)
    _run_computation()
    xp.stop_trace()
    tempdir2 = self.create_tempdir().full_path
    xp.start_trace(tempdir2)
    _run_computation()
    xp.stop_trace()
    files = glob.glob(
        os.path.join(tempdir, '**', '*.xplane.pb'), recursive=True)
    self.assertEqual(len(files), 1)
    files = glob.glob(
        os.path.join(tempdir2, '**', '*.xplane.pb'), recursive=True)
    self.assertEqual(len(files), 1)

  def test_error_double_start(self):
    tempdir = self.create_tempdir().full_path
    xp.start_trace(tempdir)
    try:
      with self.assertRaisesRegex(RuntimeError,
                                  "Only one profile may be run at a time."):
        xp.start_trace(tempdir)
    finally:
      xp.stop_trace()

  def test_error_stop_before_start(self):
    with self.assertRaisesRegex(RuntimeError, "No profile started"):
      xp.stop_trace()


if __name__ == '__main__':
  absltest.main()
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
@@ -37,6 +37,7 @@ run_xla_hlo_debug python3 "$TEST_CDIR/scan/test_scan_debug.py"
python3 "$TEST_CDIR/test_pallas.py" -v
python3 "$TEST_CDIR/test_pallas_spmd.py"
XLA_DISABLE_FUNCTIONALIZATION=1 python3 "$TEST_CDIR/test_pallas_spmd.py"
python3 "$TEST_CDIR/test_profiler_session.py"
python3 "$TEST_CDIR/test_multi_queries_paged_attention_kernel.py"
python3 "$TEST_CDIR/test_ragged_paged_attention_kernel.py"
python3 "$TEST_CDIR/test_input_output_aliases.py"
18 changes: 18 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -978,6 +978,24 @@ void BuildProfilerSubmodule(py::module* m) {
      [](const std::string& name) -> std::unique_ptr<torch::lazy::ScopePusher> {
        return absl::make_unique<torch::lazy::ScopePusher>(name);
      });

  // Profiler Session Definition.
  py::class_<runtime::profiler::TslProfilerSessionWrapper,
             std::unique_ptr<runtime::profiler::TslProfilerSessionWrapper>>
      profiler_session_class(profiler, "TslProfilerSessionWrapper");
  profiler_session_class.def(
      py::init(&runtime::profiler::TslProfilerSessionWrapper::Create));
  profiler_session_class.def("stop", [](py::object self) -> py::bytes {
    std::string xspace_str =
        py::cast<runtime::profiler::TslProfilerSessionWrapper*>(self)->Stop();
    return py::bytes(xspace_str);
  });
  profiler_session_class.def("export", [](py::object self, py::bytes xspace,
                                          const std::string& dump_dir) {
    const std::string xspace_str = xspace.cast<std::string>();
    py::cast<runtime::profiler::TslProfilerSessionWrapper*>(self)->Export(
        xspace_str, dump_dir);
  });
}

class PyLoweringContext {
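For orientation, here is a minimal sketch of driving the binding above directly from Python; the high-level start_trace/stop_trace helpers added to torch_xla/debug/profiler.py (later in this diff) wrap exactly this sequence. The dump directory is illustrative:

import torch_xla

# py::init is bound to TslProfilerSessionWrapper::Create, so constructing
# the wrapper starts a profiling session.
sess = torch_xla._XLAC.profiler.TslProfilerSessionWrapper()
# ... run the workload to be profiled ...
xspace_bytes = sess.stop()  # serialized tensorflow.profiler.XSpace proto
sess.export(xspace_bytes, '/tmp/tb_logdir')  # '/tmp/tb_logdir' is a placeholder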
2 changes: 2 additions & 0 deletions torch_xla/csrc/runtime/BUILD
@@ -302,8 +302,10 @@ cc_library(
"@com_google_absl//absl/status",
"@xla//xla/backends/profiler/plugin:profiler_c_api_hdrs",
"@xla//xla/backends/profiler/plugin:plugin_tracer",
"@xla//xla/pjrt:status_casters",
"@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",
"@tsl//tsl/profiler/lib:profiler_factory",
"@tsl//tsl/profiler/lib:profiler_session",
"@xla//xla/tsl/profiler/rpc:profiler_server_impl",
"@xla//xla/tsl/profiler/rpc/client:capture_profile",
"@com_google_absl//absl/container:flat_hash_map",
26 changes: 26 additions & 0 deletions torch_xla/csrc/runtime/profiler.cc
@@ -7,6 +7,7 @@
#include "xla/backends/profiler/plugin/plugin_tracer.h"
#include "xla/backends/profiler/plugin/profiler_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_profiler_extension.h"
#include "xla/pjrt/status_casters.h"
#include "xla/tsl/profiler/rpc/client/capture_profile.h"
#include "xla/tsl/profiler/rpc/profiler_server.h"

@@ -45,6 +46,31 @@ void ProfilerServer::Start(int port) {

ProfilerServer::~ProfilerServer() {}

const std::string TslProfilerSessionWrapper::Stop() const {
  tensorflow::profiler::XSpace xspace;
  // Stops the ProfilerSession and collects its data into the XSpace proto.
  xla::ThrowIfError(this->session->CollectData(&xspace));
  std::string xspace_str = xspace.SerializeAsString();
  return xspace_str;
}

void TslProfilerSessionWrapper::Export(
    const std::string& xspace_str, const std::string& tensorboard_dir) const {
  tensorflow::profiler::XSpace xspace_proto;
  xspace_proto.ParseFromString(xspace_str);
  xla::ThrowIfError(
      tsl::profiler::ExportToTensorBoard(xspace_proto, tensorboard_dir,
                                         /*also_export_trace_json=*/true));
}

std::unique_ptr<TslProfilerSessionWrapper> TslProfilerSessionWrapper::Create() {
  tensorflow::ProfileOptions options = tsl::ProfilerSession::DefaultOptions();
  options.set_python_tracer_level(1);
  options.set_enable_hlo_proto(true);
  return absl::make_unique<runtime::profiler::TslProfilerSessionWrapper>(
      tsl::ProfilerSession::Create(options));
}

absl::Status Trace(
    const char* service_addr, const char* logdir, int duration_ms,
    int num_tracing_attempts,
20 changes: 20 additions & 0 deletions torch_xla/csrc/runtime/profiler.h
@@ -5,6 +5,7 @@

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "tsl/profiler/lib/profiler_session.h"
#include "xla/pjrt/c/pjrt_c_api.h"

namespace torch_xla {
@@ -23,6 +24,25 @@ class ProfilerServer {
  std::unique_ptr<Impl> impl_;
};

// The profiler session implementation is based on OpenXLA's. We cannot reuse
// its Python binding directly because it uses nanobind while torch_xla uses
// pybind11.
// https://github.com/openxla/xla/blob/main/xla/python/profiler.cc
class TslProfilerSessionWrapper {
 public:
  static std::unique_ptr<TslProfilerSessionWrapper> Create();

  explicit TslProfilerSessionWrapper(
      std::unique_ptr<tsl::ProfilerSession> session)
      : session(std::move(session)) {}

  void Export(const std::string& xspace_str,
              const std::string& tensorboard_dir) const;
  const std::string Stop() const;

 private:
  std::unique_ptr<tsl::ProfilerSession> session;
};

absl::Status Trace(
    const char* service_addr, const char* logdir, int duration_ms,
    int num_tracing_attempts,
64 changes: 64 additions & 0 deletions torch_xla/debug/profiler.py
@@ -1,5 +1,7 @@
import functools
import os
import threading

import torch_xla
import torch_xla.core.xla_model as xm

@@ -183,3 +185,65 @@ def wrapper_trace_me(*args, **kwargs):
    return wrapper_trace_me

  return decorator_trace_me


# The profiler implementation is based on the JAX implementation:
# https://github.com/jax-ml/jax/blob/main/jax/_src/profiler.py
class _ProfileState:

  def __init__(self):
    self.profile_session = None
    self.log_dir = None
    self.create_perfetto_link = False
    self.create_perfetto_trace = False
    self.lock = threading.Lock()

  def reset(self):
    _profile_state.profile_session = None
    _profile_state.create_perfetto_link = False
    _profile_state.create_perfetto_trace = False
    _profile_state.log_dir = None


_profile_state = _ProfileState()


def start_trace(log_dir: os.PathLike | str) -> None:
  """Starts a profiler trace.

  The trace will capture CPU, GPU, and/or TPU activity, including Python
  functions and PyTorch/XLA on-device operations. Use :func:`stop_trace` to end
  the trace and save the results to ``log_dir``.

  The resulting trace can be viewed with TensorBoard. Note that TensorBoard
  doesn't need to be running when collecting the trace.

  Only one trace may be collected at a time. A RuntimeError will be raised if
  :func:`start_trace` is called while another trace is running.

  Args:
    log_dir: The directory to save the profiler trace to (usually the
      TensorBoard log directory).
  """
  with _profile_state.lock:
    if _profile_state.profile_session is not None:
      raise RuntimeError("Profile has already been started. "
                         "Only one profile may be run at a time.")

    _profile_state.profile_session = torch_xla._XLAC.profiler.TslProfilerSessionWrapper()
    _profile_state.log_dir = str(log_dir)


def stop_trace() -> None:
  """Stops the currently-running profiler trace.

  The trace will be saved to the ``log_dir`` passed to the corresponding
  :func:`start_trace` call. Raises a RuntimeError if a trace hasn't been
  started.
  """
  with _profile_state.lock:
    if _profile_state.profile_session is None:
      raise RuntimeError("No profile started")
    sess = _profile_state.profile_session
    sess.export(sess.stop(), str(_profile_state.log_dir))
    _profile_state.reset()
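
End to end, the new user-facing API looks like this (a minimal sketch; the log directory and workload are placeholders):

import torch
import torch_xla.debug.profiler as xp

xp.start_trace('/tmp/tensorboard')  # only one trace may run at a time

x = torch.randn(10, 10).to('xla')   # any PyTorch/XLA workload
y = (x @ x).cpu()                   # .cpu() forces device execution

xp.stop_trace()  # exports *.xplane.pb files under /tmp/tensorboard

The resulting directory can then be opened with TensorBoard's profile plugin; TensorBoard does not need to be running while the trace is collected.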