Skip to content

Commit

Permalink
[Mosaic GPU] Add CUPTI profiler alongside events-based implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
andportnoy committed Dec 9, 2024
1 parent 12b45b3 commit cc22334
Show file tree
Hide file tree
Showing 8 changed files with 301 additions and 20 deletions.
2 changes: 1 addition & 1 deletion jax/experimental/mosaic/gpu/examples/flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,7 @@ def benchmark_and_verify(
head_dim=head_dim,
**kwargs,
)
out, runtime = profiler.measure(f, q[0], k[0], v[0])
out, runtime = profiler.measure(f)(q[0], k[0], v[0])
out = out[None]

@jax.jit
Expand Down
6 changes: 3 additions & 3 deletions jax/experimental/mosaic/gpu/examples/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def verify(
wgmma_impl=WGMMADefaultImpl,
profiler_spec=prof_spec,
)
z, runtime = profiler.measure(f, x, y)
z, runtime = profiler.measure(f)(x, y)

if rhs_transpose:
dimension_numbers = ((1,), (1,)), ((), ())
Expand All @@ -382,7 +382,7 @@ def ref_f(x, y):
preferred_element_type=out_dtype,
).astype(out_dtype)

ref, ref_runtime = profiler.measure(ref_f, x, y)
ref, ref_runtime = profiler.measure(ref_f)(x, y)
np.testing.assert_allclose(
z.astype(jnp.float32), ref.astype(jnp.float32), atol=1e-3, rtol=1e-3
)
Expand Down Expand Up @@ -426,7 +426,7 @@ def ref_f(x, y):
f = build_kernel(
m, n, k, dtype, dtype, dtype, wgmma_impl=WGMMADefaultImpl, **kwargs
)
_, runtime = profiler.measure(f, x, y)
_, runtime = profiler.measure(f)(x, y)
except ValueError as e:
if "Mosaic GPU kernel exceeds available shared memory" not in str(e):
raise
Expand Down
98 changes: 86 additions & 12 deletions jax/experimental/mosaic/gpu/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,20 +69,9 @@ def _event_elapsed(start_event, end_event):
)(start_event, end_event)


def measure(
def _measure_events(
f: Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> tuple[T, float]:
"""Measures the time it takes to execute the function on the GPU.
Args:
f: The function to measure. It must accept at least one argument and return
at least one output to be measurable.
*args: The arguments to pass to ``f``.
**kwargs: The keyword arguments to pass to ``f``.
Returns:
The return value of ``f`` and the elapsed time in milliseconds.
"""
if not has_registrations:
raise RuntimeError(
"This function requires jaxlib >=0.4.36 with CUDA support."
Expand All @@ -109,6 +98,91 @@ def run(*args, **kwargs):
return outs, float(elapsed)


def _measure_cupti(f, aggregate):
def wrapper(*args, **kwargs):
mosaic_gpu_lib._mosaic_gpu_ext._cupti_init()
try:
results = jax.block_until_ready(jax.jit(f)(*args, **kwargs))
finally:
timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings()
if not timings:
return results, None
elif aggregate:
return results, sum(item[1] for item in timings)
else:
return results, timings
return wrapper


def measure(f: Callable, *, mode: str = "cupti", aggregate: bool = True
) -> Callable:
"""Sets up a function ``f`` for profiling on GPU.
``measure`` is a higher-order function that augments the argument ``f`` to
return GPU runtime in milliseconds, in addition to its proper outputs.
Args:
f: The function to measure. It must accept at least one argument and return
at least one output to be measurable.
mode: The mode of operation. Possible values are:
- "cupti", for CUPTI-based profiling.
- "events", for CUDA events-based profiling.
The two modes use different measurement methodologies and should not be
treated as interchangeable backends. See the Notes section for important
discussion.
aggregate: Whether to report an aggregate runtime. When ``False`` (only
supported by ``mode="cupti"``), the per-kernel timings are returned as a
list of tuples ``(<kernel name>, <runtime in ms>)``.
Returns:
A new function ``g`` that returns the measured GPU runtime as its last
additional output. Otherwise ``g`` accepts the same inputs and returns the
same outputs as ``f``.
Notes:
`CUPTI (CUDA Profiling Tools Interface)
<https://docs.nvidia.com/cupti/index.html>`_ is a high-accuracy,
high-precision profiling and tracing API, used in particular by Nsight
Systems and Nsight Compute. When using ``measure`` with ``mode="cupti"``,
device (GPU) execution runtimes are recorded for each kernel launched
during the execution of the function. In that mode, setting
``aggregate=True`` will sum the individual kernel runtimes to arrive at an
aggregate measurement. The "gaps" between the kernels when the device is
idle are not included in the aggregate.
The CUPTI API only allows a single "subscriber". This means that the
CUPTI-based profiler will fail when the program is run using tools that
make use of CUPTI, such as CUDA-GDB, Compute Sanitizer, Nsight Systems, or
Nsight Compute.
``mode="events"`` uses a different approach: a CUDA event is recorded
before and after the function ``f`` is executed. The reported runtime is
the time elapsed between the two events. In particular, included in the
measurement are:
- any potential "gaps" between the kernels when the device is idle
- any potential "gaps" between the "before" event and the start of the
first kernel, or between the end of the last kernel and the "after" event
In an attempt to minimize the second effect, internally the events-based
implementation may execute ``f`` more than once to "warm up" and exclude
compilation time from the measurement.
"""
match mode:
case "cupti":
return _measure_cupti(f, aggregate)
case "events":
if not aggregate:
raise ValueError(f"{aggregate=} is not supported with {mode=}")
def measure_events_wrapper(*args, **kwargs):
return _measure_events(f, *args, **kwargs)
return measure_events_wrapper
case _:
raise ValueError(f"Unrecognized profiler mode {mode}")


class ProfilerSpec:
ENTER = 0
EXIT = 1 << 31
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/pallas/ops/gpu/attention_mgpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def main(unused_argv):
for block_kv in (256, 128, 64):
config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2)
try:
out, runtime_ms = profiler.measure(functools.partial(attention, config=config), q, k, v)
out, runtime_ms = profiler.measure(functools.partial(attention, config=config))(q, k, v)
if seq_len < 32768:
out_ref = attention_reference(q, k, v)
np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3)
Expand Down
106 changes: 105 additions & 1 deletion jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,17 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <cstddef>
#include <cstdint>
#include <new>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

#include "nanobind/nanobind.h"
#include "nanobind/stl/tuple.h"
#include "nanobind/stl/vector.h"
#include "absl/cleanup/cleanup.h"
#include "absl/strings/str_cat.h"
#include "jaxlib/gpu/vendor.h"
Expand Down Expand Up @@ -118,6 +124,75 @@ XLA_FFI_Error* EventElapsed(XLA_FFI_CallFrame* call_frame) {
return kEventElapsed->Call(call_frame);
}

#define THROW(...) \
do { \
throw std::runtime_error( \
absl::StrCat("Mosaic GPU profiler error: ", __VA_ARGS__)); \
} while (0)

#define THROW_IF(expr, ...) \
do { \
if (expr) THROW(__VA_ARGS__); \
} while (0)

#define THROW_IF_CUPTI_ERROR(expr, ...) \
do { \
CUptiResult _result = (expr); \
if (_result != CUPTI_SUCCESS) { \
const char* s; \
cuptiGetErrorMessage(_result, &s); \
THROW(s, ": " __VA_OPT__(, ) __VA_ARGS__); \
} \
} while (0)

// CUPTI can only have one subscriber per process, so it's ok to make the
// profiler state global.
struct {
CUpti_SubscriberHandle subscriber;
std::vector<std::tuple<const char* /*kernel_name*/, double /*ms*/>> timings;
} profiler_state;

void callback_request(uint8_t** buffer, size_t* size, size_t* maxNumRecords) {
// 10 MiB buffer size is generous but somewhat arbitrary, it's at the upper
// bound of what's recommended in CUPTI documentation:
// https://docs.nvidia.com/cupti/main/main.html#cupti-callback-api:~:text=For%20typical%20workloads%2C%20it%E2%80%99s%20suggested%20to%20choose%20a%20size%20between%201%20and%2010%20MB.
const int buffer_size = 10 * (1 << 20);
// 8 byte alignment is specified in the official CUPTI code samples, see
// extras/CUPTI/samples/common/helper_cupti_activity.h in your CUDA
// installation.
*buffer = new (std::align_val_t(8)) uint8_t[buffer_size];
*size = buffer_size;
*maxNumRecords = 0;
}

void callback_complete(CUcontext context, uint32_t streamId,
uint8_t* buffer, size_t size, size_t validSize) {
// take ownership of the buffer once CUPTI is done using it
absl::Cleanup cleanup = [buffer]() {
operator delete[](buffer, std::align_val_t(8));
};
CUpti_Activity* record = nullptr;
while (true) {
CUptiResult status = cuptiActivityGetNextRecord(buffer, validSize, &record);
if (status == CUPTI_SUCCESS) {
if (record->kind == CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL) {
// TODO(andportnoy) handle multi-GPU
CUpti_ActivityKernel9* kernel = (CUpti_ActivityKernel9*)record;
// Convert integer nanoseconds to floating point milliseconds to match
// the interface of the events-based profiler.
double duration_ms = (kernel->end - kernel->start) / 1e6;
profiler_state.timings.push_back(
std::make_tuple(kernel->name, duration_ms));
}
} else if (status == CUPTI_ERROR_MAX_LIMIT_REACHED) {
// no more records available
break;
} else {
THROW_IF_CUPTI_ERROR(status);
}
}
}

NB_MODULE(_mosaic_gpu_ext, m) {
m.def("registrations", []() {
return nb::make_tuple(
Expand All @@ -139,6 +214,35 @@ NB_MODULE(_mosaic_gpu_ext, m) {
}
}
});
m.def("_cupti_init", []() {
profiler_state.timings.clear();
// Ok to pass nullptr for the callback here because we don't register any
// callbacks through cuptiEnableCallback.
auto subscribe_result = cuptiSubscribe(
&profiler_state.subscriber, /*callback=*/nullptr, /*userdata=*/nullptr);
if (subscribe_result == CUPTI_ERROR_MULTIPLE_SUBSCRIBERS_NOT_SUPPORTED) {
THROW(
"Attempted to subscribe to CUPTI while another subscriber, such as "
"Nsight Systems or Nsight Compute, is active. CUPTI backend of the "
"Mosaic GPU profiler cannot be used in that mode since CUPTI does "
"not support multiple subscribers.");
}
THROW_IF_CUPTI_ERROR(subscribe_result, "failed to subscribe to CUPTI");
THROW_IF_CUPTI_ERROR(
cuptiActivityRegisterCallbacks(callback_request, callback_complete),
"failed to register CUPTI activity callbacks");
THROW_IF_CUPTI_ERROR(
cuptiActivityEnable(CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL),
"failed to enable tracking of kernel activity by CUPTI");
});
m.def("_cupti_get_timings", []() {
THROW_IF_CUPTI_ERROR(cuptiUnsubscribe(profiler_state.subscriber),
"failed to unsubscribe from CUPTI");
THROW_IF_CUPTI_ERROR(cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_NONE),
"failed to flush CUPTI activity buffers");
THROW_IF_CUPTI_ERROR(cuptiFinalize(), "failed to detach CUPTI");
return profiler_state.timings;
});
}

} // namespace
Expand Down
14 changes: 14 additions & 0 deletions tests/mosaic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,17 @@ jax_multiplatform_test(
"//jax/experimental/mosaic/gpu/examples:flash_attention",
] + py_deps("absl/testing"),
)

jax_multiplatform_test(
name = "profiler_cupti_test",
srcs = ["profiler_cupti_test.py"],
enable_backends = [],
enable_configs = ["gpu_h100"],
deps = [
"//jax:mosaic_gpu",
] + py_deps("absl/testing"),
tags = [
"noasan", # CUPTI leaks memory
"nomsan",
],
)
5 changes: 3 additions & 2 deletions tests/mosaic/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1691,9 +1691,10 @@ def kernel(ctx, inp, out, smem):

class ProfilerTest(TestCase):

def test_measure(self):
def test_measure_events_explicit(self):
x = jnp.arange(1024 * 1024)
profiler.measure(lambda x, y: x + y, x, x) # This is just a smoke test
_, runtime_ms = profiler.measure(lambda x, y: x + y, mode="events")(x, x)
self.assertIsInstance(runtime_ms, float)

def test_profile(self):
def kernel(ctx, src, dst, _):
Expand Down
Loading

0 comments on commit cc22334

Please sign in to comment.