Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow users to use their own cuda contexts and streams in JIT mode #6345

Merged
merged 10 commits into from
Oct 28, 2021
116 changes: 92 additions & 24 deletions src/JITModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,63 +506,72 @@ void merge_handlers(JITHandlers &base, const JITHandlers &addins) {
if (addins.custom_get_library_symbol) {
base.custom_get_library_symbol = addins.custom_get_library_symbol;
}
if (addins.custom_cuda_acquire_context) {
base.custom_cuda_acquire_context = addins.custom_cuda_acquire_context;
}
if (addins.custom_cuda_release_context) {
base.custom_cuda_release_context = addins.custom_cuda_release_context;
}
if (addins.custom_cuda_get_stream) {
base.custom_cuda_get_stream = addins.custom_cuda_get_stream;
}
}

// Dispatch halide_print for JIT code: a handler installed on the specific
// JITUserContext takes precedence; otherwise fall back to the globally
// active handlers. (Resolves the diff residue that left both the old
// and new lines of this function in place.)
void print_handler(JITUserContext *context, const char *msg) {
    if (context && context->handlers.custom_print) {
        context->handlers.custom_print(context, msg);
    } else {
        return active_handlers.custom_print(context, msg);
    }
}

// Dispatch halide_malloc for JIT code: prefer the per-context handler when
// one is installed, otherwise use the globally active handler. (Resolves
// the diff residue that duplicated the old and new lines.)
void *malloc_handler(JITUserContext *context, size_t x) {
    if (context && context->handlers.custom_malloc) {
        return context->handlers.custom_malloc(context, x);
    } else {
        return active_handlers.custom_malloc(context, x);
    }
}

// Dispatch halide_free for JIT code: prefer the per-context handler when
// one is installed, otherwise use the globally active handler. (Resolves
// the diff residue that duplicated the old and new lines.)
void free_handler(JITUserContext *context, void *ptr) {
    if (context && context->handlers.custom_free) {
        context->handlers.custom_free(context, ptr);
    } else {
        active_handlers.custom_free(context, ptr);
    }
}

// Dispatch halide_do_task for JIT code: prefer the per-context handler when
// one is installed, otherwise use the globally active handler. (Resolves
// the diff residue that duplicated the old and new lines.)
int do_task_handler(JITUserContext *context, int (*f)(JITUserContext *, int, uint8_t *), int idx,
                    uint8_t *closure) {
    if (context && context->handlers.custom_do_task) {
        return context->handlers.custom_do_task(context, f, idx, closure);
    } else {
        return active_handlers.custom_do_task(context, f, idx, closure);
    }
}

// Dispatch halide_do_par_for for JIT code: prefer the per-context handler
// when one is installed, otherwise use the globally active handler.
// (Resolves the diff residue that duplicated the old and new lines.)
int do_par_for_handler(JITUserContext *context, int (*f)(JITUserContext *, int, uint8_t *),
                       int min, int size, uint8_t *closure) {
    if (context && context->handlers.custom_do_par_for) {
        return context->handlers.custom_do_par_for(context, f, min, size, closure);
    } else {
        return active_handlers.custom_do_par_for(context, f, min, size, closure);
    }
}

// Dispatch halide_error for JIT code: prefer the per-context handler when
// one is installed, otherwise use the globally active handler. (Resolves
// the diff residue that duplicated the old and new lines.)
void error_handler_handler(JITUserContext *context, const char *msg) {
    if (context && context->handlers.custom_error) {
        context->handlers.custom_error(context, msg);
    } else {
        active_handlers.custom_error(context, msg);
    }
}

// Dispatch halide_trace for JIT code: prefer the per-context handler when
// one is installed, otherwise use the globally active handler. (Resolves
// the diff residue that duplicated the old and new lines.)
int32_t trace_handler(JITUserContext *context, const halide_trace_event_t *e) {
    if (context && context->handlers.custom_trace) {
        return context->handlers.custom_trace(context, e);
    } else {
        return active_handlers.custom_trace(context, e);
    }
}

Expand All @@ -578,6 +587,30 @@ void *get_library_symbol_handler(void *lib, const char *name) {
return (*active_handlers.custom_get_library_symbol)(lib, name);
}

// Dispatch the cuda acquire-context hook: a handler installed on this
// specific JITUserContext wins; otherwise fall back to the globally
// active handlers.
int cuda_acquire_context_handler(JITUserContext *context, void **cuda_context_ptr, bool create) {
    auto *fn = (context != nullptr) ? context->handlers.custom_cuda_acquire_context : nullptr;
    if (fn == nullptr) {
        fn = active_handlers.custom_cuda_acquire_context;
    }
    return fn(context, cuda_context_ptr, create);
}

// Dispatch the cuda release-context hook: a handler installed on this
// specific JITUserContext wins; otherwise fall back to the globally
// active handlers.
int cuda_release_context_handler(JITUserContext *context) {
    auto *fn = (context != nullptr) ? context->handlers.custom_cuda_release_context : nullptr;
    if (fn == nullptr) {
        fn = active_handlers.custom_cuda_release_context;
    }
    return fn(context);
}

// Dispatch the cuda get-stream hook: a handler installed on this specific
// JITUserContext wins; otherwise fall back to the globally active handlers.
int cuda_get_stream_handler(JITUserContext *context, void *cuda_context, void **cuda_stream_ptr) {
    auto *fn = (context != nullptr) ? context->handlers.custom_cuda_get_stream : nullptr;
    if (fn == nullptr) {
        fn = active_handlers.custom_cuda_get_stream;
    }
    return fn(context, cuda_context, cuda_stream_ptr);
}

template<typename function_t>
function_t hook_function(const std::map<std::string, JITModule::Symbol> &exports, const char *hook_name, function_t hook) {
auto iter = exports.find(hook_name);
Expand Down Expand Up @@ -776,13 +809,13 @@ JITModule &make_module(llvm::Module *for_module, Target target,
hook_function(runtime.exports(), "halide_set_custom_trace", trace_handler);

runtime_internal_handlers.custom_get_symbol =
hook_function(shared_runtimes(MainShared).exports(), "halide_set_custom_get_symbol", get_symbol_handler);
hook_function(runtime.exports(), "halide_set_custom_get_symbol", get_symbol_handler);

runtime_internal_handlers.custom_load_library =
hook_function(shared_runtimes(MainShared).exports(), "halide_set_custom_load_library", load_library_handler);
hook_function(runtime.exports(), "halide_set_custom_load_library", load_library_handler);

runtime_internal_handlers.custom_get_library_symbol =
hook_function(shared_runtimes(MainShared).exports(), "halide_set_custom_get_library_symbol", get_library_symbol_handler);
hook_function(runtime.exports(), "halide_set_custom_get_library_symbol", get_library_symbol_handler);

active_handlers = runtime_internal_handlers;
merge_handlers(active_handlers, default_handlers);
Expand All @@ -794,6 +827,41 @@ JITModule &make_module(llvm::Module *for_module, Target target,
runtime.jit_module->name = "MainShared";
} else {
runtime.jit_module->name = "GPU";

// There are two versions of these cuda context
// management handlers we could use - one in the cuda
// module, and one in the cuda-debug module. If both
// modules are in use, we'll just want to use one of
// them, so that we don't needlessly create two cuda
// contexts. We'll use whichever was first
// created. The second one will then declare a
// dependency on the first one, to make sure things
// are destroyed in the correct order.

if (runtime_kind == CUDA || runtime_kind == CUDADebug) {
if (!runtime_internal_handlers.custom_cuda_acquire_context) {
// Neither module has been created.
runtime_internal_handlers.custom_cuda_acquire_context =
hook_function(runtime.exports(), "halide_set_cuda_acquire_context", cuda_acquire_context_handler);

runtime_internal_handlers.custom_cuda_release_context =
hook_function(runtime.exports(), "halide_set_cuda_release_context", cuda_release_context_handler);

runtime_internal_handlers.custom_cuda_get_stream =
hook_function(runtime.exports(), "halide_set_cuda_get_stream", cuda_get_stream_handler);

active_handlers = runtime_internal_handlers;
merge_handlers(active_handlers, default_handlers);
} else if (runtime_kind == CUDA) {
// The CUDADebug module has already been created.
// Use the context in the CUDADebug module and add
// a dependence edge from the CUDA module to it.
runtime.add_dependency(shared_runtimes(CUDADebug));
} else {
// The CUDA module has already been created.
runtime.add_dependency(shared_runtimes(CUDA));
}
}
}

uint64_t arg_addr =
Expand Down
16 changes: 16 additions & 0 deletions src/JITModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,22 @@ struct JITHandlers {
* an opened library. Equivalent to dlsym. Takes a handle
* returned by custom_load_library as the first argument. */
void *(*custom_get_library_symbol)(void *lib, const char *name){nullptr};

/** A custom method for the Halide runtime to acquire a cuda
* context. The cuda context is treated as a void * to avoid a
* dependence on the cuda headers. If the create argument is set
* to true, a context should be created if one does not already
* exist. */
int32_t (*custom_cuda_acquire_context)(JITUserContext *user_context, void **cuda_context_ptr, bool create){nullptr};

/** The Halide runtime calls this when it is done with a cuda
* context. The default implementation does nothing. */
int32_t (*custom_cuda_release_context)(JITUserContext *user_context){nullptr};

/** A custom method for the Halide runtime to acquire a cuda
* stream to use. The cuda context and stream are both modelled
* as a void *, to avoid a dependence on the cuda headers. */
int32_t (*custom_cuda_get_stream)(JITUserContext *user_context, void *cuda_context, void **stream_ptr){nullptr};
};

namespace Internal {
Expand Down
2 changes: 1 addition & 1 deletion src/Pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,7 @@ Realization Pipeline::realize(JITUserContext *context,
if (needs_crop) {
r[i].crop(crop);
}
r[i].copy_to_host();
r[i].copy_to_host(context);
}
return r;
}
Expand Down
17 changes: 17 additions & 0 deletions src/runtime/HalideRuntimeCuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,23 @@ extern uintptr_t halide_cuda_get_device_ptr(void *user_context, struct halide_bu
* driver. See halide_reuse_device_allocations. */
extern int halide_cuda_release_unused_device_allocations(void *user_context);

// These typedefs treat both a CUcontext and a CUstream as a void *,
// to avoid dependencies on cuda headers.
typedef int (*halide_cuda_acquire_context_t)(void *, // user_context
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Completely orthogonal to this PR, but IMHO we should consider migrating typedef bar foo; to using foo = bar; as I think it reads easier and is easier to search for)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this header is supposed to compile in C mode.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and/or with janky legacy toolchains

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahhhh right

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Do we actually compile/run any tests in plain-C mode? If not, we should add one)

void **, // cuda context out parameter
bool); // should create a context if none exist
typedef int (*halide_cuda_release_context_t)(void * /* user_context */);
typedef int (*halide_cuda_get_stream_t)(void *, // user_context
void *, // context
void **); // stream out parameter

/** Set custom methods to acquire and release cuda contexts and streams */
// @{
extern halide_cuda_acquire_context_t halide_set_cuda_acquire_context(halide_cuda_acquire_context_t handler);
extern halide_cuda_release_context_t halide_set_cuda_release_context(halide_cuda_release_context_t handler);
extern halide_cuda_get_stream_t halide_set_cuda_get_stream(halide_cuda_get_stream_t handler);
// @}

#ifdef __cplusplus
} // End extern "C"
#endif
Expand Down
86 changes: 77 additions & 9 deletions src/runtime/cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ extern "C" {
// - A call to halide_cuda_acquire_context is followed by a matching call to
// halide_cuda_release_context. halide_cuda_acquire_context should block while a
// previous call (if any) has not yet been released via halide_cuda_release_context.
WEAK int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create = true) {
WEAK int halide_default_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create = true) {
// TODO: Should we use a more "assertive" assert? these asserts do
// not block execution on failure.
halide_assert(user_context, ctx != nullptr);
Expand Down Expand Up @@ -179,7 +179,7 @@ WEAK int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool cr
return 0;
}

WEAK int halide_cuda_release_context(void *user_context) {
WEAK int halide_default_cuda_release_context(void *user_context) {
return 0;
}

Expand All @@ -188,7 +188,7 @@ WEAK int halide_cuda_release_context(void *user_context) {
// for the context (nullptr stream). The context is passed in for convenience, but
// any sort of scoping must be handled by that of the
// halide_cuda_acquire_context/halide_cuda_release_context pair, not this call.
WEAK int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream) {
WEAK int halide_default_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream) {
// There are two default streams we could use. stream 0 is fully
// synchronous. stream 2 gives a separate non-blocking stream per
// thread.
Expand All @@ -198,6 +198,53 @@ WEAK int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *str

} // extern "C"

namespace Halide {
namespace Runtime {
namespace Internal {
namespace CUDA {

// The currently-installed cuda context/stream hooks. They start out pointing
// at the weak default implementations defined earlier in this file and can be
// replaced at runtime via the halide_set_cuda_* setters below. The casts
// erase the CUcontext/CUstream parameter types to void * to match the public
// typedefs in HalideRuntimeCuda.h.
WEAK halide_cuda_acquire_context_t acquire_context = (halide_cuda_acquire_context_t)halide_default_cuda_acquire_context;
WEAK halide_cuda_release_context_t release_context = (halide_cuda_release_context_t)halide_default_cuda_release_context;
WEAK halide_cuda_get_stream_t get_stream = (halide_cuda_get_stream_t)halide_default_cuda_get_stream;

}  // namespace CUDA
}  // namespace Internal
}  // namespace Runtime
}  // namespace Halide

extern "C" {

// Public entry point: routes through whichever acquire-context hook is
// currently installed, erasing CUcontext to void * at the boundary.
WEAK int halide_cuda_acquire_context(void *user_context, CUcontext *ctx, bool create = true) {
    void **opaque_ctx = (void **)ctx;
    return CUDA::acquire_context(user_context, opaque_ctx, create);
}

// Install a custom acquire-context hook and return the previously-installed
// one so callers can chain to it or restore it later.
WEAK halide_cuda_acquire_context_t halide_set_cuda_acquire_context(halide_cuda_acquire_context_t handler) {
    const halide_cuda_acquire_context_t previous = CUDA::acquire_context;
    CUDA::acquire_context = handler;
    return previous;
}

// Public entry point: routes through whichever release-context hook is
// currently installed.
WEAK int halide_cuda_release_context(void *user_context) {
    return CUDA::release_context(user_context);
}

// Install a custom release-context hook and return the previously-installed
// one so callers can chain to it or restore it later.
WEAK halide_cuda_release_context_t halide_set_cuda_release_context(halide_cuda_release_context_t handler) {
    const halide_cuda_release_context_t previous = CUDA::release_context;
    CUDA::release_context = handler;
    return previous;
}

// Public entry point: routes through whichever get-stream hook is currently
// installed, erasing CUcontext/CUstream to void * at the boundary.
WEAK int halide_cuda_get_stream(void *user_context, CUcontext ctx, CUstream *stream) {
    void *opaque_ctx = (void *)ctx;
    void **opaque_stream = (void **)stream;
    return CUDA::get_stream(user_context, opaque_ctx, opaque_stream);
}

// Install a custom get-stream hook and return the previously-installed one
// so callers can chain to it or restore it later.
WEAK halide_cuda_get_stream_t halide_set_cuda_get_stream(halide_cuda_get_stream_t handler) {
    const halide_cuda_get_stream_t previous = CUDA::get_stream;
    CUDA::get_stream = handler;
    return previous;
}
}

namespace Halide {
namespace Runtime {
namespace Internal {
Expand Down Expand Up @@ -845,7 +892,8 @@ WEAK int halide_cuda_device_malloc(void *user_context, halide_buffer_t *buf) {

namespace {
WEAK int cuda_do_multidimensional_copy(void *user_context, const device_copy &c,
uint64_t src, uint64_t dst, int d, bool from_host, bool to_host) {
uint64_t src, uint64_t dst, int d, bool from_host, bool to_host,
CUstream stream) {
if (d > MAX_COPY_DIMS) {
error(user_context) << "Buffer has too many dimensions to copy to/from GPU\n";
return -1;
Expand All @@ -858,15 +906,27 @@ WEAK int cuda_do_multidimensional_copy(void *user_context, const device_copy &c,
if (!from_host && to_host) {
debug(user_context) << "cuMemcpyDtoH(" << (void *)dst << ", " << (void *)src << ", " << c.chunk_size << ")\n";
copy_name = "cuMemcpyDtoH";
err = cuMemcpyDtoH((void *)dst, (CUdeviceptr)src, c.chunk_size);
if (stream) {
err = cuMemcpyDtoHAsync((void *)dst, (CUdeviceptr)src, c.chunk_size, stream);
} else {
err = cuMemcpyDtoH((void *)dst, (CUdeviceptr)src, c.chunk_size);
}
} else if (from_host && !to_host) {
debug(user_context) << "cuMemcpyHtoD(" << (void *)dst << ", " << (void *)src << ", " << c.chunk_size << ")\n";
copy_name = "cuMemcpyHtoD";
err = cuMemcpyHtoD((CUdeviceptr)dst, (void *)src, c.chunk_size);
if (stream) {
err = cuMemcpyHtoDAsync((CUdeviceptr)dst, (void *)src, c.chunk_size, stream);
} else {
err = cuMemcpyHtoD((CUdeviceptr)dst, (void *)src, c.chunk_size);
}
} else if (!from_host && !to_host) {
debug(user_context) << "cuMemcpyDtoD(" << (void *)dst << ", " << (void *)src << ", " << c.chunk_size << ")\n";
copy_name = "cuMemcpyDtoD";
err = cuMemcpyDtoD((CUdeviceptr)dst, (CUdeviceptr)src, c.chunk_size);
if (stream) {
err = cuMemcpyDtoDAsync((CUdeviceptr)dst, (CUdeviceptr)src, c.chunk_size, stream);
} else {
err = cuMemcpyDtoD((CUdeviceptr)dst, (CUdeviceptr)src, c.chunk_size);
}
} else if (dst != src) {
debug(user_context) << "memcpy(" << (void *)dst << ", " << (void *)src << ", " << c.chunk_size << ")\n";
// Could reach here if a user called directly into the
Expand All @@ -881,7 +941,7 @@ WEAK int cuda_do_multidimensional_copy(void *user_context, const device_copy &c,
} else {
ssize_t src_off = 0, dst_off = 0;
for (int i = 0; i < (int)c.extent[d - 1]; i++) {
int err = cuda_do_multidimensional_copy(user_context, c, src + src_off, dst + dst_off, d - 1, from_host, to_host);
int err = cuda_do_multidimensional_copy(user_context, c, src + src_off, dst + dst_off, d - 1, from_host, to_host, stream);
dst_off += c.dst_stride_bytes[d - 1];
src_off += c.src_stride_bytes[d - 1];
if (err) {
Expand Down Expand Up @@ -938,7 +998,15 @@ WEAK int halide_cuda_buffer_copy(void *user_context, struct halide_buffer_t *src
}
#endif

err = cuda_do_multidimensional_copy(user_context, c, c.src + c.src_begin, c.dst, dst->dimensions, from_host, to_host);
CUstream stream = nullptr;
if (cuStreamSynchronize != nullptr) {
int result = halide_cuda_get_stream(user_context, ctx.context, &stream);
if (result != 0) {
error(user_context) << "CUDA: In cuda_do_multidimensional_copy, halide_cuda_get_stream returned " << result << "\n";
}
}

err = cuda_do_multidimensional_copy(user_context, c, c.src + c.src_begin, c.dst, dst->dimensions, from_host, to_host, stream);

#ifdef DEBUG_RUNTIME
uint64_t t_after = halide_current_time_ns(user_context);
Expand Down
5 changes: 5 additions & 0 deletions src/runtime/cuda_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ CUDA_FN_3020(CUresult, cuMemFree, cuMemFree_v2, (CUdeviceptr dptr));
CUDA_FN_3020(CUresult, cuMemcpyHtoD, cuMemcpyHtoD_v2, (CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount));
CUDA_FN_3020(CUresult, cuMemcpyDtoH, cuMemcpyDtoH_v2, (void *dstHost, CUdeviceptr srcDevice, size_t ByteCount));
CUDA_FN_3020(CUresult, cuMemcpyDtoD, cuMemcpyDtoD_v2, (CUdeviceptr dstHost, CUdeviceptr srcDevice, size_t ByteCount));

CUDA_FN_3020(CUresult, cuMemcpyHtoDAsync, cuMemcpyHtoDAsync_v2, (CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream stream));
CUDA_FN_3020(CUresult, cuMemcpyDtoHAsync, cuMemcpyDtoHAsync_v2, (void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream stream));
CUDA_FN_3020(CUresult, cuMemcpyDtoDAsync, cuMemcpyDtoDAsync_v2, (CUdeviceptr dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream stream));

CUDA_FN_3020(CUresult, cuMemcpy3D, cuMemcpy3D_v2, (const CUDA_MEMCPY3D *pCopy));
CUDA_FN(CUresult, cuLaunchKernel, (CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra));
CUDA_FN(CUresult, cuCtxSynchronize, ());
Expand Down
Loading