[Mosaic GPU] Add CUPTI profiler alongside events-based implementation #24805
38419c3
to
86ba114
Compare
775078a
to
061d3c1
Compare
8e5168c
to
754c58e
Compare
754c58e
to
a7d2562
Compare
Looks great!
jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
Outdated
// take ownership of the buffer once CUPTI is done using it
std::unique_ptr<uint8_t> buffer = absl::WrapUnique(buffer_raw);
CUpti_Activity* record = nullptr;
for (;;) {
nit: `while (true)` is a little clearer imo?
Changed to `while (true)`.
jaxlib/mosaic/gpu/BUILD
Outdated
@@ -188,6 +188,7 @@ pybind_extension(
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/strings",
"@nanobind",
"@xla//xla/pjrt:exceptions",
Could we try to avoid this dependency? IIUC the only reason why it's here is so that you can throw `XlaRuntimeError`. But it would be perfectly ok to raise a different Python error. The dep right now causes some issues with the build internally.
Replaced with `std::runtime_error`.
(`XlaRuntimeError` had the benefit that it respects `JAX_TRACEBACK_FILTERING`).
d89fc98
to
e02cc0e
Compare
Ok this looks great, but was flagged by our internal ASAN harness for two reasons:
- If you use aligned `new[]`, you must also use the aligned `delete[]` operator (currently the code uses an unaligned `delete`, which is doubly bad). This is fixed in the following diff (+ some missing C++ headers):
diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
--- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
+++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
@@ -13,9 +13,13 @@ See the License for the specific languag
limitations under the License.
==============================================================================*/
-#include <memory>
+#include <cstddef>
+#include <cstdint>
+#include <new>
#include <stdexcept>
#include <string>
+#include <tuple>
+#include <vector>
#include "nanobind/nanobind.h"
#include "nanobind/stl/tuple.h"
@@ -162,13 +166,14 @@ void callback_request(uint8_t** buffer,
}
void callback_complete(CUcontext context, uint32_t streamId,
- uint8_t* buffer_raw, size_t size, size_t validSize) {
+ uint8_t* buffer, size_t size, size_t validSize) {
// take ownership of the buffer once CUPTI is done using it
- std::unique_ptr<uint8_t> buffer = absl::WrapUnique(buffer_raw);
+ absl::Cleanup cleanup = [buffer]() {
+ operator delete[](buffer, std::align_val_t(8));
+ };
CUpti_Activity* record = nullptr;
while (true) {
- CUptiResult status =
- cuptiActivityGetNextRecord(buffer.get(), validSize, &record);
+ CUptiResult status = cuptiActivityGetNextRecord(buffer, validSize, &record);
if (status == CUPTI_SUCCESS) {
if (record->kind == CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL) {
// TODO(andportnoy) handle multi-GPU
- It looks like CUPTI loves to leak memory, which makes our ASAN test harness very sad. Please make sure to separate CUPTI tests into a separate test target where we can disable ASAN, because we want to continue to make sure we don't leak memory otherwise.
e02cc0e
to
7836d89
Compare
Thanks, I applied the patch and moved CUPTI-specific tests to a separate target which I marked with |
case "cupti":
  return _measure_cupti(f, aggregate)
case "events":
  if aggregate == False:
Nit: `if not aggregate`?
Changed to `if not aggregate`.
  return wrapper


def measure(f, mode="cupti", aggregate=True):
Maybe make both optional parameters keyword-only?
Done.
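For readers following the thread, here is a minimal sketch of what keyword-only optional parameters look like. `measure_sketch` is a hypothetical stand-in, not the function from this PR; only the signature shape (a bare `*` before `mode` and `aggregate`) mirrors the suggestion.

```python
from typing import Callable


def measure_sketch(
    f: Callable, *, mode: str = "cupti", aggregate: bool = True
) -> Callable:
  """Hypothetical stand-in: echoes the options instead of profiling."""
  def wrapper(*args, **kwargs):
    # A real profiler would time f here; this sketch only records the options.
    return f(*args, **kwargs), {"mode": mode, "aggregate": aggregate}
  return wrapper


# Options passed by keyword are accepted.
result, options = measure_sketch(lambda x: x * 2, mode="events")(21)

# Options passed positionally are rejected with a TypeError.
try:
  measure_sketch(lambda x: x * 2, "events", False)
  positional_rejected = False
except TypeError:
  positional_rejected = True
```

The bare `*` keeps call sites self-describing (`aggregate=False` rather than a lone `False`) and leaves room to reorder or add options later without breaking callers.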
@@ -69,7 +69,7 @@ def _event_elapsed(start_event, end_event):
   )(start_event, end_event)


-def measure(
+def _measure_events(
     f: Callable[P, T], *args: P.args, **kwargs: P.kwargs
 ) -> tuple[T, float]:
   """Measures the time it takes to execute the function on the GPU.
This docstring should be moved to the public `measure` function and extended to explain `mode` and `aggregate`.
I added a detailed docstring to `measure`.
  return wrapper


def measure(f, mode="cupti", aggregate=True):
Would you mind adding type annotations to `measure`?
Done.
tests/mosaic/BUILD
Outdated
deps = [
    "//jax:mosaic_gpu",
] + py_deps("absl/testing"),
tags = ["noasan"],  # CUPTI leaks memory
Err, is this a bug in CUPTI or are you referring to the static global you are adding?
This is a bug in CUPTI. Could you also please add `nomsan`?
Done.
# pylint: disable=g-complex-comprehension
config.parse_flags_with_absl()

class ProfilerCuptiTest(parameterized.TestCase):
It might be better to do it in a follow up, but we probably need a single profiler test which would be parameterized over `mode`.
Idk the tests seem reasonable to me?
The two `mode`s are fundamentally different in a pretty significant way (see `measure` docstring), so it felt more natural to write dedicated test cases.
Please fix Sergei's comments too!
7836d89
to
5f580af
Compare
  return wrapper


def measure(f: Callable, *, mode: str = "cupti", aggregate: bool = True):
FYI you can be a bit more precise with the types here:
def measure(f: Callable[P, T], ...) -> Callable[P, tuple[T, float]]: ...
That return type would be incorrect when `aggregate=False`. Wouldn't hurt to add `-> Callable` return type though, right?
Added `-> Callable`.
Well, if you really want it, you can define overloads for `measure` with literal types for `aggregate`, but I have mixed feelings about having different return types tbh as mentioned in the other thread.
> I have mixed feelings about having different return types

Why though? The default value of `aggregate` is `True`, which means for both modes the return types are the same. The user needs to consciously type out `aggregate=False` to actually get that array of tuples instead of the default aggregate value.
def wrapper(*args, **kwargs):
  mosaic_gpu_lib._mosaic_gpu_ext._cupti_init()
  try:
    results = jax.block_until_ready(jax.jit(f)(*args, **kwargs))
Shall we jit `f` in the enclosing namespace?
Hmm what do you mean?
Sorry, I meant `jit_f = jax.jit(f)` just before the definition of `wrapper`.
Why? If you are thinking about performance, `jax.jit` doesn't actually JIT until the invocation anyway, right? And if the function has been compiled for the shapes and types before, then it's a quick cache lookup anyway.
What am I missing?
timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings()
if not timings:
  return results, None
elif aggregate:
Is it useful to have `aggregate=False`? Can we assume it's true so that the return type of `measure` is the same for both modes?
Definitely useful, I'd rather keep it.
The user can always aggregate manually, it's just one comprehension away.
Curious what @apaszke thinks as well.
> The user can always aggregate manually

This seems to suggest "let's make `aggregate=False` the only option" because the user can aggregate themselves, but then the return types are going to be different between the two modes.

> Can we assume it's true so that the return type of measure is the same for both modes?

This seems to suggest making `aggregate=True` the only option to make the return types uniform.
Am I missing something, or are these contradictory?
I think it's valuable to be able to look at precise individual (disaggregated) kernel timings; this is a crucial bit of functionality that you can only get with CUPTI and not with events. Keeping it is more important than making the return types uniform.
We settled on this design (aggregate/summed timings by default with an option to see individual timings) with @apaszke over DMs over a month ago, but I should have posted more widely so we could have had this discussion earlier :)
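The "one comprehension away" remark can be illustrated as follows, assuming the disaggregated result is a list of `(kernel_name, runtime_ms)` tuples. That layout is a guess (the thread only says "array of tuples"), and the values below are made-up placeholders, not measurements.

```python
from collections import defaultdict

# Made-up placeholder records in the assumed (kernel_name, runtime_ms) layout.
timings = [("kernel_a", 1.0), ("kernel_b", 2.0), ("kernel_a", 0.5)]

# Summed total: the kind of single number the default aggregate mode
# is described as returning.
total_ms = sum(runtime_ms for _, runtime_ms in timings)

# Per-kernel totals, which are only recoverable from the individual records.
per_kernel_ms = defaultdict(float)
for name, runtime_ms in timings:
  per_kernel_ms[name] += runtime_ms
```

Going from individual records to a total is cheap, while the reverse is impossible, which is the asymmetry behind keeping the disaggregated option.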
5f580af
to
cc22334
Compare
No description provided.