diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py
index 4728f00a9243..6e9ba6a382db 100644
--- a/jax/experimental/mosaic/gpu/examples/flash_attention.py
+++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py
@@ -576,7 +576,7 @@ def benchmark_and_verify(
       head_dim=head_dim,
       **kwargs,
   )
-  out, runtime = profiler.measure(f, q[0], k[0], v[0])
+  out, runtime = profiler.measure(f)(q[0], k[0], v[0])
   out = out[None]
 
   @jax.jit
diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py
index 7aa96e7fa5d3..dd05c1e2ecc8 100644
--- a/jax/experimental/mosaic/gpu/examples/matmul.py
+++ b/jax/experimental/mosaic/gpu/examples/matmul.py
@@ -360,7 +360,7 @@ def verify(
       wgmma_impl=WGMMADefaultImpl,
       profiler_spec=prof_spec,
   )
-  z, runtime = profiler.measure(f, x, y)
+  z, runtime = profiler.measure(f)(x, y)
 
   if rhs_transpose:
     dimension_numbers = ((1,), (1,)), ((), ())
@@ -382,7 +382,7 @@ def ref_f(x, y):
         preferred_element_type=out_dtype,
     ).astype(out_dtype)
 
-  ref, ref_runtime = profiler.measure(ref_f, x, y)
+  ref, ref_runtime = profiler.measure(ref_f)(x, y)
   np.testing.assert_allclose(
       z.astype(jnp.float32), ref.astype(jnp.float32), atol=1e-3, rtol=1e-3
   )
@@ -426,7 +426,7 @@ def ref_f(x, y):
       f = build_kernel(
           m, n, k, dtype, dtype, dtype, wgmma_impl=WGMMADefaultImpl, **kwargs
       )
-      _, runtime = profiler.measure(f, x, y)
+      _, runtime = profiler.measure(f)(x, y)
     except ValueError as e:
       if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
        raise
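(Annotation, not part of the patch.) The call-site updates above follow the new `profiler.measure` interface: instead of taking the operands directly, `measure` now returns a wrapper that is then called with them. A minimal sketch of the change in calling convention; the toy workload and inputs are hypothetical:

```python
import jax.numpy as jnp
from jax.experimental.mosaic.gpu import profiler

f = lambda x, y: x + y          # hypothetical toy workload
x = y = jnp.ones((1024, 1024))  # hypothetical inputs

# Before this change: out, runtime_ms = profiler.measure(f, x, y)
# After: measure(f) returns a reusable wrapper that takes the operands.
measured = profiler.measure(f)
out, runtime_ms = measured(x, y)
```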
diff --git a/jax/experimental/mosaic/gpu/profiler.py b/jax/experimental/mosaic/gpu/profiler.py
index e51a7b842931..c6bf7ed9dbe9 100644
--- a/jax/experimental/mosaic/gpu/profiler.py
+++ b/jax/experimental/mosaic/gpu/profiler.py
@@ -69,7 +69,7 @@ def _event_elapsed(start_event, end_event):
   )(start_event, end_event)
 
 
-def measure(
+def _measure_events(
     f: Callable[P, T], *args: P.args, **kwargs: P.kwargs
 ) -> tuple[T, float]:
   """Measures the time it takes to execute the function on the GPU.
@@ -109,6 +109,36 @@ def run(*args, **kwargs):
   return outs, float(elapsed)
 
 
+def _measure_cupti(f, aggregate):
+  def wrapper(*args, **kwargs):
+    mosaic_gpu_lib._mosaic_gpu_ext._cupti_init()
+    try:
+      results = jax.block_until_ready(jax.jit(f)(*args, **kwargs))
+    finally:
+      timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings()
+    if not timings:
+      return results, None
+    elif aggregate:
+      return results, sum(item[1] for item in timings)
+    else:
+      return results, timings
+  return wrapper
+
+
+def measure(f, mode="cupti", aggregate=True):
+  match mode:
+    case "cupti":
+      return _measure_cupti(f, aggregate)
+    case "events":
+      if aggregate == False:
+        raise ValueError(f"{aggregate=} is not supported with {mode=}")
+      def measure_events_wrapper(*args, **kwargs):
+        return _measure_events(f, *args, **kwargs)
+      return measure_events_wrapper
+    case _:
+      raise ValueError(f"Unrecognized profiler mode {mode}")
+
+
 class ProfilerSpec:
   ENTER = 0
   EXIT = 1 << 31
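(Annotation, not part of the patch.) A sketch of how the new interface reads in user code; the toy workload below is hypothetical. `mode="cupti"` (the default) can return either one aggregated runtime or the raw per-kernel timings, while `mode="events"` keeps the CUDA-event path and only supports the aggregate form:

```python
import jax.numpy as jnp
from jax.experimental.mosaic.gpu import profiler

f = lambda x: 2 * x          # hypothetical toy workload
x = jnp.arange(1024 * 1024)  # hypothetical input

# Default: CUPTI backend, aggregated runtime in milliseconds
# (or None if no kernels were captured).
out, runtime_ms = profiler.measure(f)(x)

# Per-kernel timings: a list of (kernel_name, runtime_ms) tuples.
out, timings = profiler.measure(f, aggregate=False)(x)
for name, kernel_ms in timings:
  print(f"{name}: {kernel_ms:.3f} ms")

# CUDA-events backend, aggregate only, as before.
out, runtime_ms = profiler.measure(f, mode="events")(x)
```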
diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py
index 294ef153ff93..1f3bf408d4cb 100644
--- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py
+++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py
@@ -274,7 +274,7 @@ def main(unused_argv):
     for block_kv in (256, 128, 64):
       config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2)
       try:
-        out, runtime_ms = profiler.measure(functools.partial(attention, config=config), q, k, v)
+        out, runtime_ms = profiler.measure(functools.partial(attention, config=config))(q, k, v)
         if seq_len < 32768:
           out_ref = attention_reference(q, k, v)
           np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3)
diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
index 608270239882..8f61da6b2404 100644
--- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
+++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
@@ -18,6 +18,8 @@ limitations under the License.
 #include <cstdint>
 
 #include "nanobind/nanobind.h"
+#include "nanobind/stl/tuple.h"
+#include "nanobind/stl/vector.h"
 #include "absl/cleanup/cleanup.h"
 #include "absl/strings/str_cat.h"
 #include "jaxlib/gpu/vendor.h"
@@ -118,6 +120,74 @@ XLA_FFI_Error* EventElapsed(XLA_FFI_CallFrame* call_frame) {
   return kEventElapsed->Call(call_frame);
 }
 
+#define THROW(...)                                                  \
+  do {                                                              \
+    throw std::runtime_error(                                       \
+        absl::StrCat("Mosaic GPU profiler error: ", __VA_ARGS__));  \
+  } while (0)
+
+#define THROW_IF(expr, ...)        \
+  do {                             \
+    if (expr) THROW(__VA_ARGS__);  \
+  } while (0)
+
+#define THROW_IF_CUPTI_ERROR(expr, ...)           \
+  do {                                            \
+    CUptiResult _result = (expr);                 \
+    if (_result != CUPTI_SUCCESS) {               \
+      const char* s;                              \
+      cuptiGetErrorMessage(_result, &s);          \
+      THROW(s, ": " __VA_OPT__(, ) __VA_ARGS__);  \
+    }                                             \
+  } while (0)
+
+// CUPTI can only have one subscriber per process, so it's ok to make the
+// profiler state global.
+struct {
+  CUpti_SubscriberHandle subscriber;
+  std::vector<std::tuple<std::string, double>> timings;
+} profiler_state;
+
+void callback_request(uint8_t** buffer, size_t* size, size_t* maxNumRecords) {
+  // 10 MiB buffer size is generous but somewhat arbitrary, it's at the upper
+  // bound of what's recommended in CUPTI documentation:
+  // https://docs.nvidia.com/cupti/main/main.html#cupti-callback-api:~:text=For%20typical%20workloads%2C%20it%E2%80%99s%20suggested%20to%20choose%20a%20size%20between%201%20and%2010%20MB.
+  const int buffer_size = 10 * (1 << 20);
+  // 8 byte alignment is specified in the official CUPTI code samples, see
+  // extras/CUPTI/samples/common/helper_cupti_activity.h in your CUDA
+  // installation.
+  *buffer = new (std::align_val_t(8)) uint8_t[buffer_size];
+  *size = buffer_size;
+  *maxNumRecords = 0;
+}
+
+void callback_complete(CUcontext context, uint32_t streamId,
+                       uint8_t* buffer_raw, size_t size, size_t validSize) {
+  // take ownership of the buffer once CUPTI is done using it
+  std::unique_ptr<uint8_t> buffer = absl::WrapUnique(buffer_raw);
+  CUpti_Activity* record = nullptr;
+  while (true) {
+    CUptiResult status =
+        cuptiActivityGetNextRecord(buffer.get(), validSize, &record);
+    if (status == CUPTI_SUCCESS) {
+      if (record->kind == CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL) {
+        // TODO(andportnoy) handle multi-GPU
+        CUpti_ActivityKernel9* kernel = (CUpti_ActivityKernel9*)record;
+        // Convert integer nanoseconds to floating point milliseconds to match
+        // the interface of the events-based profiler.
+        double duration_ms = (kernel->end - kernel->start) / 1e6;
+        profiler_state.timings.push_back(
+            std::make_tuple(kernel->name, duration_ms));
+      }
+    } else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) {
+      // no more records available
+      break;
+    } else {
+      THROW_IF_CUPTI_ERROR(status);
+    }
+  }
+}
+
 NB_MODULE(_mosaic_gpu_ext, m) {
   m.def("registrations", []() {
     return nb::make_tuple(
@@ -139,6 +209,35 @@ NB_MODULE(_mosaic_gpu_ext, m) {
         }
       }
     });
+  m.def("_cupti_init", []() {
+    profiler_state.timings.clear();
+    // Ok to pass nullptr for the callback here because we don't register any
+    // callbacks through cuptiEnableCallback.
+    auto subscribe_result = cuptiSubscribe(
+        &profiler_state.subscriber, /*callback=*/nullptr, /*userdata=*/nullptr);
+    if (subscribe_result == CUPTI_ERROR_MULTIPLE_SUBSCRIBERS_NOT_SUPPORTED) {
+      THROW(
+          "Attempted to subscribe to CUPTI while another subscriber, such as "
+          "Nsight Systems or Nsight Compute, is active. CUPTI backend of the "
+          "Mosaic GPU profiler cannot be used in that mode since CUPTI does "
+          "not support multiple subscribers.");
+    }
+    THROW_IF_CUPTI_ERROR(subscribe_result, "failed to subscribe to CUPTI");
+    THROW_IF_CUPTI_ERROR(
+        cuptiActivityRegisterCallbacks(callback_request, callback_complete),
+        "failed to register CUPTI activity callbacks");
+    THROW_IF_CUPTI_ERROR(
+        cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL),
+        "failed to enable tracking of kernel activity by CUPTI");
+  });
+  m.def("_cupti_get_timings", []() {
+    THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber),
+                         "failed to unsubscribe from CUPTI");
+    THROW_IF_CUPTI_ERROR(cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_NONE),
+                         "failed to flush CUPTI activity buffers");
+    THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI");
+    return profiler_state.timings;
+  });
 }
 
 }  // namespace
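(Annotation, not part of the patch.) Because CUPTI supports only one subscriber per process, the CUPTI backend refuses to run under Nsight Systems or Nsight Compute and surfaces the error above as a Python `RuntimeError`. A hedged sketch of a fallback to the events backend; the helper name is hypothetical:

```python
from jax.experimental.mosaic.gpu import profiler

def measure_with_fallback(f, *args):
  # Hypothetical helper: prefer the CUPTI backend, fall back to CUDA events
  # when another CUPTI subscriber (e.g. an Nsight tool) is already attached.
  try:
    return profiler.measure(f, mode="cupti")(*args)
  except RuntimeError as e:
    if "Attempted to subscribe to CUPTI" not in str(e):
      raise
    return profiler.measure(f, mode="events")(*args)
```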
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index 80c4048720a3..37561e62488e 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -1720,6 +1720,62 @@ def kernel(ctx, src, dst, _):
     jax.block_until_ready(f(xd))
 
 
+class ProfilerMeasureTest(TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.x = jnp.arange(1024 * 1024)
+    self.f = lambda x: 2*x
+
+  def test_measure(self):
+    _, runtime_ms = profiler.measure(self.f)(self.x)
+    self.assertIsInstance(runtime_ms, float)
+
+  def test_measure_cupti_explicit(self):
+    _, runtime_ms = profiler.measure(self.f, mode="cupti")(self.x)
+    self.assertIsInstance(runtime_ms, float)
+
+  def test_measure_events_explicit(self):
+    _, runtime_ms = profiler.measure(self.f, mode="events")(self.x)
+    self.assertIsInstance(runtime_ms, float)
+
+  def test_measure_per_kernel(self):
+    _, runtimes_ms = profiler.measure(self.f, aggregate=False)(self.x)
+    for item in runtimes_ms:
+      self.assertIsInstance(item, tuple)
+      self.assertEqual(len(item), 2)
+      name, runtime_ms = item
+      self.assertIsInstance(name, str)
+      self.assertIsInstance(runtime_ms, float)
+
+  def test_measure_cupti_repeated(self):
+    f_profiled = profiler.measure(self.f, mode="cupti")
+    n = 3
+    timings = [f_profiled(self.x)[1] for _ in range(n)]
+    for item in timings:
+      self.assertIsInstance(item, float)
+
+  def test_measure_repeated_interleaved(self):
+    # test that kernels run outside of measure() are not captured
+    _, timings = profiler.measure(self.f, mode='cupti', aggregate=False)(self.x)
+    self.assertEqual(len(timings), 1)
+    self.f(self.x)
+    _, timings = profiler.measure(self.f, mode='cupti', aggregate=False)(self.x)
+    self.assertEqual(len(timings), 1)
+
+  def test_measure_double_subscription(self):
+    # This needs to run in a separate process, otherwise it affects the
+    # outcomes of other tests since CUPTI state is global.
+    self.skipTest("Must run in a separate process from other profiler tests")
+    # Initialize profiler manually, which subscribes to CUPTI. There can only
+    # be one CUPTI subscriber at a time.
+    jax._src.lib.mosaic_gpu._mosaic_gpu_ext._cupti_init()
+    with self.assertRaisesRegex(RuntimeError,
+        "Attempted to subscribe to CUPTI while another subscriber, "
+        "such as Nsight Systems or Nsight Compute, is active."):
+      profiler.measure(self.f, aggregate=False)(self.x)
+
+
 class TorchTest(TestCase):
 
   @classmethod
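(Annotation, not part of the patch.) `test_measure_double_subscription` skips itself because CUPTI state is global to the process. A hypothetical way to exercise the same check without contaminating other tests is to run it in a freshly spawned process:

```python
import multiprocessing as mp

def expect_double_subscription_error():
  # Hypothetical repro of the skipped test: subscribe once via the private
  # binding, then expect profiler.measure's own subscription attempt to fail.
  import jax
  import jax.numpy as jnp
  from jax.experimental.mosaic.gpu import profiler
  jax._src.lib.mosaic_gpu._mosaic_gpu_ext._cupti_init()
  try:
    profiler.measure(lambda x: 2 * x)(jnp.arange(8))
  except RuntimeError as e:
    return "Attempted to subscribe to CUPTI" in str(e)
  return False

if __name__ == "__main__":
  # A spawned child starts with clean CUPTI state and cannot leak it back.
  with mp.get_context("spawn").Pool(1) as pool:
    assert pool.apply(expect_double_subscription_error)
```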