Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Mosaic GPU] Add CUPTI profiler alongside events-based implementation #24805

Merged

Conversation

andportnoy
Copy link
Contributor

No description provided.

@andportnoy andportnoy force-pushed the aportnoy/mosaic-gpu-cupti-profiler branch from 38419c3 to 86ba114 Compare November 8, 2024 22:36
@andportnoy andportnoy marked this pull request as ready for review November 8, 2024 22:36
@andportnoy andportnoy force-pushed the aportnoy/mosaic-gpu-cupti-profiler branch 2 times, most recently from 775078a to 061d3c1 Compare November 12, 2024 20:31
@andportnoy andportnoy force-pushed the aportnoy/mosaic-gpu-cupti-profiler branch 2 times, most recently from 8e5168c to 754c58e Compare November 20, 2024 03:53
@andportnoy andportnoy force-pushed the aportnoy/mosaic-gpu-cupti-profiler branch from 754c58e to a7d2562 Compare December 4, 2024 16:19
Copy link
Collaborator

@apaszke apaszke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!

// take ownership of the buffer once CUPTI is done using it
std::unique_ptr<uint8_t> buffer = absl::WrapUnique(buffer_raw);
CUpti_Activity* record = nullptr;
for (;;) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: while (true) is a little clearer imo?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to while (true).

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Dec 5, 2024
@@ -188,6 +188,7 @@ pybind_extension(
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/strings",
"@nanobind",
"@xla//xla/pjrt:exceptions",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we try to avoid this dependency? IIUC the only reason why it's here is so that you can throw XlaRuntimeError. But it would be perfectly ok to raise a different Python error. The dep right now causes some issues with the build internally

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replaced with std::runtime_error.
(XlaRuntimeError had the benefit that it respects JAX_TRACEBACK_FILTERING).

@andportnoy andportnoy force-pushed the aportnoy/mosaic-gpu-cupti-profiler branch 2 times, most recently from d89fc98 to e02cc0e Compare December 5, 2024 19:42
@andportnoy andportnoy requested a review from apaszke December 5, 2024 19:46
Copy link
Collaborator

@apaszke apaszke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok this looks great, but was flagged by our internal ASAN harness for two reasons:

  1. If you use aligned new[], you must also use the aligned delete[] operator (currently the code uses an unaligned delete which is doubly bad). This is fixed in the following diff (+ some missing C++ headers):
diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
--- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
+++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
@@ -13,9 +13,13 @@ See the License for the specific languag
 limitations under the License.
 ==============================================================================*/
 
-#include <memory>
+#include <cstddef>
+#include <cstdint>
+#include <new>
 #include <stdexcept>
 #include <string>
+#include <tuple>
+#include <vector>
 
 #include "nanobind/nanobind.h"
 #include "nanobind/stl/tuple.h"
@@ -162,13 +166,14 @@ void callback_request(uint8_t** buffer, 
 }
 
 void callback_complete(CUcontext context, uint32_t streamId,
-                       uint8_t* buffer_raw, size_t size, size_t validSize) {
+                       uint8_t* buffer, size_t size, size_t validSize) {
   // take ownership of the buffer once CUPTI is done using it
-  std::unique_ptr<uint8_t> buffer = absl::WrapUnique(buffer_raw);
+  absl::Cleanup cleanup = [buffer]() {
+    operator delete[](buffer, std::align_val_t(8));
+  };
   CUpti_Activity* record = nullptr;
   while (true) {
-    CUptiResult status =
-        cuptiActivityGetNextRecord(buffer.get(), validSize, &record);
+    CUptiResult status = cuptiActivityGetNextRecord(buffer, validSize, &record);
     if (status == CUPTI_SUCCESS) {
       if (record->kind == CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL) {
         // TODO(andportnoy) handle multi-GPU
  1. It looks like CUPTI loves to leak memory, which makes our ASAN test harness very sad. Please make sure to separate CUPTI tests into a separate test target where we can disable ASAN, because we want to continue to make sure we don't leak memory otherwise.

@andportnoy andportnoy force-pushed the aportnoy/mosaic-gpu-cupti-profiler branch from e02cc0e to 7836d89 Compare December 6, 2024 20:47
@andportnoy andportnoy requested a review from apaszke December 6, 2024 20:48
@andportnoy
Copy link
Contributor Author

Thanks, I applied the patch and moved CUPTI-specific tests to a separate target which I marked with tags = ["noasan"].

case "cupti":
return _measure_cupti(f, aggregate)
case "events":
if aggregate == False:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: if not aggregate?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed to if not aggregate.

return wrapper


def measure(f, mode="cupti", aggregate=True):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe make both optional parameters keyword-only?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -69,7 +69,7 @@ def _event_elapsed(start_event, end_event):
)(start_event, end_event)


def measure(
def _measure_events(
f: Callable[P, T], *args: P.args, **kwargs: P.kwargs
) -> tuple[T, float]:
"""Measures the time it takes to execute the function on the GPU.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This docstring should be move to the public measure function and extended to explain mode and aggregate.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a detailed docstring to measure.

return wrapper


def measure(f, mode="cupti", aggregate=True):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind adding type annotations to measure?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

deps = [
"//jax:mosaic_gpu",
] + py_deps("absl/testing"),
tags = ["noasan"], # CUPTI leaks memory
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Err, is this a bug in CUPTI or are you referring to the static global you are adding?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bug in CUPTI. Could you also please add nomsan?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

# pylint: disable=g-complex-comprehension
config.parse_flags_with_absl()

class ProfilerCuptiTest(parameterized.TestCase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be better to do it in a follow up, but we probably need a single profiler test which would be parameterized over mode.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Idk the tests seem reasonable to me?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The two modes are fundamentally different in a pretty significant way (see measure docstring), so it felt more natural to write dedicated test cases.

Copy link
Collaborator

@apaszke apaszke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please fix Sergei's comments too!

deps = [
"//jax:mosaic_gpu",
] + py_deps("absl/testing"),
tags = ["noasan"], # CUPTI leaks memory
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bug in CUPTI. Could you also please add nomsan?

# pylint: disable=g-complex-comprehension
config.parse_flags_with_absl()

class ProfilerCuptiTest(parameterized.TestCase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Idk the tests seem reasonable to me?

return wrapper


def measure(f: Callable, *, mode: str = "cupti", aggregate: bool = True):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI you can be a bit more precise with the types here:

def measure(f: Callable[P, T], ...) -> Callable[P, tuple[T, float]]: ...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That return type would be incorrect when aggregate=False. Wouldn't hurt to add -> Callable return type though, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added -> Callable.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, if you really want it, you can define overloads for measure with literal types for aggregate, but I have mixed feelings about having different return types tbh as mentioned in the other thread.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have mixed feelings about having different return types

Why though? The default value of aggregate is True, which means for both modes the return types are the same. The user needs to consciously type out aggregate=False to actually get that array of tuples instead of the default aggregate value.

def wrapper(*args, **kwargs):
mosaic_gpu_lib._mosaic_gpu_ext._cupti_init()
try:
results = jax.block_until_ready(jax.jit(f)(*args, **kwargs))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we jit f in the enclosing namespace?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm what do you mean?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I meant

jit_f = jax.jit(f)

just before the definition of wrapper.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? If you are thinking about performance, jax.jit doesn't actually JIT until the invocation anyway, right? And if the function has been compiled for the shapes and types before, then it's a quick cache look up anyway.

What am I missing?

timings = mosaic_gpu_lib._mosaic_gpu_ext._cupti_get_timings()
if not timings:
return results, None
elif aggregate:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it useful to have aggregate=False? Can we assume its true so that the return type of measure is the same for both modes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely useful, I'd rather keep it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The user can always aggregate manually, it's just one comprehension away.

Curious what @apaszke thinks as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The user can always aggregate manually

This seems to suggest "let's make aggregate=False the only option" because the user can aggregate themselves, but then the return types are going to be different between the two modes.

Can we assume its true so that the return type of measure is the same for both modes?

This seems to suggest making aggregate=True the only option to make the return types uniform.

Am I missing something or these are contradictory?

I think it's valuable to be able to look at precise individual (disaggregated) kernel timings, this is a crucial bit of functionality that you can only get with CUPTI and not with events. Keeping it is more important than making the return types uniform.

We settled on this design (aggregate/summed timings by default with an option to see individual timings) with @apaszke over DMs over a month ago, but I should have posted more widely so we could have had this discussion earlier :)

@andportnoy andportnoy force-pushed the aportnoy/mosaic-gpu-cupti-profiler branch from 5f580af to cc22334 Compare December 9, 2024 19:31
@copybara-service copybara-service bot merged commit 0d7eaeb into jax-ml:main Dec 11, 2024
11 of 12 checks passed
@andportnoy andportnoy deleted the aportnoy/mosaic-gpu-cupti-profiler branch December 11, 2024 14:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants