[Mosaic GPU] Add CUPTI profiler alongside events-based implementation
andportnoy committed Dec 4, 2024
1 parent 12b45b3 commit a7d2562
Showing 7 changed files with 192 additions and 6 deletions.
2 changes: 1 addition & 1 deletion jax/experimental/mosaic/gpu/examples/flash_attention.py
@@ -576,7 +576,7 @@ def benchmark_and_verify(
       head_dim=head_dim,
       **kwargs,
   )
-  out, runtime = profiler.measure(f, q[0], k[0], v[0])
+  out, runtime = profiler.measure(f)(q[0], k[0], v[0])
   out = out[None]
 
   @jax.jit
6 changes: 3 additions & 3 deletions jax/experimental/mosaic/gpu/examples/matmul.py
@@ -360,7 +360,7 @@ def verify(
       wgmma_impl=WGMMADefaultImpl,
       profiler_spec=prof_spec,
   )
-  z, runtime = profiler.measure(f, x, y)
+  z, runtime = profiler.measure(f)(x, y)
 
   if rhs_transpose:
     dimension_numbers = ((1,), (1,)), ((), ())
@@ -382,7 +382,7 @@ def ref_f(x, y):
         preferred_element_type=out_dtype,
     ).astype(out_dtype)
 
-  ref, ref_runtime = profiler.measure(ref_f, x, y)
+  ref, ref_runtime = profiler.measure(ref_f)(x, y)
   np.testing.assert_allclose(
       z.astype(jnp.float32), ref.astype(jnp.float32), atol=1e-3, rtol=1e-3
   )
@@ -426,7 +426,7 @@ def ref_f(x, y):
     f = build_kernel(
         m, n, k, dtype, dtype, dtype, wgmma_impl=WGMMADefaultImpl, **kwargs
     )
-    _, runtime = profiler.measure(f, x, y)
+    _, runtime = profiler.measure(f)(x, y)
   except ValueError as e:
     if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
       raise
32 changes: 31 additions & 1 deletion jax/experimental/mosaic/gpu/profiler.py
@@ -69,7 +69,7 @@ def _event_elapsed(start_event, end_event):
   )(start_event, end_event)
 
 
-def measure(
+def _measure_events(
     f: Callable[P, T], *args: P.args, **kwargs: P.kwargs
 ) -> tuple[T, float]:
   """Measures the time it takes to execute the function on the GPU.
@@ -109,6 +109,36 @@ def run(*args, **kwargs):
   return outs, float(elapsed)


+def _measure_cupti(f, aggregate):
+  def wrapper(*args, **kwargs):
+    mosaic_gpu_lib._mosaic_gpu_ext._cupti_init()
+    try:
+      results = jax.block_until_ready(jax.jit(f)(*args, **kwargs))
+    finally:
+      timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings()
+    if not timings:
+      return results, None
+    elif aggregate:
+      return results, sum(item[1] for item in timings)
+    else:
+      return results, timings
+  return wrapper
+
+
+def measure(f, mode="cupti", aggregate=True):
+  match mode:
+    case "cupti":
+      return _measure_cupti(f, aggregate)
+    case "events":
+      if not aggregate:
+        raise ValueError(f"{aggregate=} is not supported with {mode=}")
+      def measure_events_wrapper(*args, **kwargs):
+        return _measure_events(f, *args, **kwargs)
+      return measure_events_wrapper
+    case _:
+      raise ValueError(f"Unrecognized profiler mode {mode}")


 class ProfilerSpec:
   ENTER = 0
   EXIT = 1 << 31
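The call-site churn above follows from the new API: measure() is now curried, returning a wrapped callable instead of invoking f directly, and it takes mode= ("cupti" or "events") and aggregate= keyword arguments. A minimal usage sketch (the doubling function and input below are illustrative, not part of this commit):

    import jax.numpy as jnp
    from jax.experimental.mosaic.gpu import profiler

    x = jnp.arange(1024 * 1024)
    f = lambda x: 2 * x

    # Default: CUPTI backend, a single time aggregated over all kernels
    # launched by f (None if no kernel was launched).
    out, total_ms = profiler.measure(f)(x)

    # Per-kernel timings as a list of (kernel_name, milliseconds) tuples.
    out, timings = profiler.measure(f, mode="cupti", aggregate=False)(x)

    # Legacy CUDA-events backend; only aggregate timing is supported.
    out, total_ms = profiler.measure(f, mode="events")(x)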
2 changes: 1 addition & 1 deletion jax/experimental/pallas/ops/gpu/attention_mgpu.py
@@ -274,7 +274,7 @@ def main(unused_argv):
   for block_kv in (256, 128, 64):
     config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2)
     try:
-      out, runtime_ms = profiler.measure(functools.partial(attention, config=config), q, k, v)
+      out, runtime_ms = profiler.measure(functools.partial(attention, config=config))(q, k, v)
       if seq_len < 32768:
         out_ref = attention_reference(q, k, v)
         np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3)
1 change: 1 addition & 0 deletions jaxlib/mosaic/gpu/BUILD
@@ -188,6 +188,7 @@ pybind_extension(
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/strings",
"@nanobind",
"@xla//xla/pjrt:exceptions",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
"@xla//xla/tsl/cuda:cudart",
100 changes: 100 additions & 0 deletions jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
@@ -18,12 +18,15 @@ limitations under the License.
#include <string>

#include "nanobind/nanobind.h"
#include "nanobind/stl/tuple.h"
#include "nanobind/stl/vector.h"
#include "absl/cleanup/cleanup.h"
#include "absl/strings/str_cat.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_nanobind_helpers.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
#include "xla/pjrt/exceptions.h"

 namespace jax::cuda {
 namespace {
@@ -118,6 +121,74 @@ XLA_FFI_Error* EventElapsed(XLA_FFI_CallFrame* call_frame) {
return kEventElapsed->Call(call_frame);
}

+#define THROW(...)                                                 \
+  do {                                                             \
+    throw xla::XlaRuntimeError(                                    \
+        absl::StrCat("Mosaic GPU profiler error: ", __VA_ARGS__)); \
+  } while (0)
+
+#define THROW_IF(expr, ...)       \
+  do {                            \
+    if (expr) THROW(__VA_ARGS__); \
+  } while (0)
+
+#define THROW_IF_CUPTI_ERROR(expr, ...)          \
+  do {                                           \
+    CUptiResult _result = (expr);                \
+    if (_result != CUPTI_SUCCESS) {              \
+      const char* s;                             \
+      cuptiGetErrorMessage(_result, &s);         \
+      THROW(s, ": " __VA_OPT__(, ) __VA_ARGS__); \
+    }                                            \
+  } while (0)

+// CUPTI can only have one subscriber per process, so it's ok to make the
+// profiler state global.
+struct {
+  CUpti_SubscriberHandle subscriber;
+  std::vector<std::tuple<const char* /*kernel_name*/, double /*ms*/>> timings;
+} profiler_state;
+
+void callback_request(uint8_t** buffer, size_t* size, size_t* maxNumRecords) {
+  // A 10 MiB buffer is generous but somewhat arbitrary; it's at the upper
+  // bound of what the CUPTI documentation recommends:
+  // https://docs.nvidia.com/cupti/main/main.html#cupti-callback-api:~:text=For%20typical%20workloads%2C%20it%E2%80%99s%20suggested%20to%20choose%20a%20size%20between%201%20and%2010%20MB.
+  const int buffer_size = 10 * (1 << 20);
+  // 8 byte alignment is specified in the official CUPTI code samples, see
+  // extras/CUPTI/samples/common/helper_cupti_activity.h in your CUDA
+  // installation.
+  *buffer = new (std::align_val_t(8)) uint8_t[buffer_size];
+  *size = buffer_size;
+  *maxNumRecords = 0;
+}

+void callback_complete(CUcontext context, uint32_t streamId,
+                       uint8_t* buffer_raw, size_t size, size_t validSize) {
+  // Take ownership of the buffer once CUPTI is done using it. Use the array
+  // form of unique_ptr so the array-new allocation above is released with
+  // delete[].
+  std::unique_ptr<uint8_t[]> buffer(buffer_raw);
+  CUpti_Activity* record = nullptr;
+  for (;;) {
+    CUptiResult status =
+        cuptiActivityGetNextRecord(buffer.get(), validSize, &record);
+    if (status == CUPTI_SUCCESS) {
+      if (record->kind == CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL) {
+        // TODO(andportnoy) handle multi-GPU
+        CUpti_ActivityKernel9* kernel = (CUpti_ActivityKernel9*)record;
+        // Convert integer nanoseconds to floating point milliseconds to match
+        // the interface of the events-based profiler.
+        double duration_ms = (kernel->end - kernel->start) / 1e6;
+        profiler_state.timings.push_back(
+            std::make_tuple(kernel->name, duration_ms));
+      }
+    } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) {
+      // No more records available.
+      break;
+    } else {
+      THROW_IF_CUPTI_ERROR(status);
+    }
+  }
+}

 NB_MODULE(_mosaic_gpu_ext, m) {
   m.def("registrations", []() {
     return nb::make_tuple(
@@ -139,6 +210,35 @@ NB_MODULE(_mosaic_gpu_ext, m) {
       }
     }
   });
m.def("_cupti_init", []() {
profiler_state.timings.clear();
// Ok to pass nullptr for the callback here because we don't register any
// callbacks through cuptiEnableCallback.
auto subscribe_result = cuptiSubscribe(
&profiler_state.subscriber, /*callback=*/nullptr, /*userdata=*/nullptr);
if (subscribe_result == CUPTI_ERROR_MULTIPLE_SUBSCRIBERS_NOT_SUPPORTED) {
THROW(
"Attempted to subscribe to CUPTI while another subscriber, such as "
"Nsight Systems or Nsight Compute, is active. CUPTI backend of the "
"Mosaic GPU profiler cannot be used in that mode since CUPTI does "
"not support multiple subscribers.");
}
THROW_IF_CUPTI_ERROR(subscribe_result, "failed to subscribe to CUPTI");
THROW_IF_CUPTI_ERROR(
cuptiActivityRegisterCallbacks(callback_request, callback_complete),
"failed to register CUPTI activity callbacks");
THROW_IF_CUPTI_ERROR(
cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL),
"failed to enable tracking of kernel activity by CUPTI");
});
m.def("_cupti_get_timings", []() {
THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber),
"failed to unsubscribe from CUPTI");
THROW_IF_CUPTI_ERROR(cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_NONE),
"failed to flush CUPTI activity buffers");
THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI");
return profiler_state.timings;
});
}

 }  // namespace
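Together, the two bindings bracket a profiled region: _cupti_init() clears any previous timings, subscribes to CUPTI, and enables concurrent-kernel activity tracking, while _cupti_get_timings() unsubscribes, flushes the activity buffers, detaches CUPTI, and returns the accumulated (kernel name, milliseconds) tuples. A sketch of driving the bindings directly from Python, mirroring what _measure_cupti does above (the doubling kernel is illustrative):

    import jax
    import jax.numpy as jnp
    from jax._src.lib import mosaic_gpu as mosaic_gpu_lib

    # Subscribe to CUPTI and start collecting kernel activity records.
    mosaic_gpu_lib._mosaic_gpu_ext._cupti_init()
    try:
      # block_until_ready ensures the kernel has finished before we collect.
      out = jax.block_until_ready(jax.jit(lambda x: 2 * x)(jnp.arange(1024)))
    finally:
      # Unsubscribe, flush, and detach CUPTI even if the computation failed.
      timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings()
    for kernel_name, duration_ms in timings:
      print(f"{kernel_name}: {duration_ms:.6f} ms")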
55 changes: 55 additions & 0 deletions tests/mosaic/gpu_test.py
@@ -1720,6 +1720,61 @@ def kernel(ctx, src, dst, _):
     jax.block_until_ready(f(xd))


+class ProfilerMeasureTest(TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.x = jnp.arange(1024 * 1024)
+    self.f = lambda x: 2 * x
+
+  def test_measure(self):
+    _, runtime_ms = profiler.measure(self.f)(self.x)
+    self.assertIsInstance(runtime_ms, float)
+
+  def test_measure_cupti_explicit(self):
+    _, runtime_ms = profiler.measure(self.f, mode="cupti")(self.x)
+    self.assertIsInstance(runtime_ms, float)
+
+  def test_measure_events_explicit(self):
+    _, runtime_ms = profiler.measure(self.f, mode="events")(self.x)
+    self.assertIsInstance(runtime_ms, float)
+
+  def test_measure_per_kernel(self):
+    _, runtimes_ms = profiler.measure(self.f, aggregate=False)(self.x)
+    for item in runtimes_ms:
+      self.assertIsInstance(item, tuple)
+      self.assertEqual(len(item), 2)
+      name, runtime_ms = item
+      self.assertIsInstance(name, str)
+      self.assertIsInstance(runtime_ms, float)
+
+  def test_measure_cupti_repeated(self):
+    f_profiled = profiler.measure(self.f, mode="cupti")
+    n = 3
+    timings = [f_profiled(self.x)[1] for _ in range(n)]
+    for item in timings:
+      self.assertIsInstance(item, float)
+
+  def test_measure_repeated_interleaved(self):
+    # Test that kernels run outside of measure() are not captured.
+    _, timings = profiler.measure(self.f, mode="cupti", aggregate=False)(self.x)
+    self.assertEqual(len(timings), 1)
+    self.f(self.x)
+    _, timings = profiler.measure(self.f, mode="cupti", aggregate=False)(self.x)
+    self.assertEqual(len(timings), 1)
+
+  def test_measure_double_subscription(self):
+    # This needs to run in a separate process; otherwise it affects the
+    # outcomes of other tests, since CUPTI state is global.
+    self.skipTest("Must run in a separate process from other profiler tests")
+    # Initialize the profiler manually, which subscribes to CUPTI. There can
+    # only be one CUPTI subscriber at a time.
+    jax._src.lib.mosaic_gpu._mosaic_gpu_ext._cupti_init()
+    with self.assertRaisesRegex(
+        RuntimeError,
+        "Attempted to subscribe to CUPTI while another subscriber, "
+        "such as Nsight Systems or Nsight Compute, is active."):
+      profiler.measure(self.f, aggregate=False)(self.x)

 class TorchTest(TestCase):
 
   @classmethod
