[CUDA][HIP] Add Event Caching #1538

Closed

64 changes: 40 additions & 24 deletions source/adapters/cuda/event.cpp
@@ -42,11 +42,34 @@ ur_event_handle_t_::ur_event_handle_t_(ur_context_handle_t Context,
   urContextRetain(Context);
 }
 
+void ur_event_handle_t_::reset() {
+  detail::ur::assertion(
+      RefCount == 0, "Attempting to reset an event that is still referenced");
+
+  HasBeenWaitedOn = false;
+  IsRecorded = false;
+  IsStarted = false;
+  Queue = nullptr;
+  Context = nullptr;
+}
+
 ur_event_handle_t_::~ur_event_handle_t_() {
-  if (Queue != nullptr) {
+  if (HasOwnership) {
+    if (EvEnd)
+      UR_CHECK_ERROR(cuEventDestroy(EvEnd));
+
+    if (EvQueued)
+      UR_CHECK_ERROR(cuEventDestroy(EvQueued));
+
+    if (EvStart)
+      UR_CHECK_ERROR(cuEventDestroy(EvStart));
+  }
+  if (Queue) {
     urQueueRelease(Queue);
   }
-  urContextRelease(Context);
+  if (Context) {
+    urContextRelease(Context);
+  }
 }
 
 ur_result_t ur_event_handle_t_::start() {
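The reset()/destructor split above is the core of the caching scheme: reset() clears only per-use state (the wait/record/start flags and the parent links), while the native CUevent handles survive for reuse and are destroyed exactly once, in the destructor, guarded by HasOwnership. A standalone sketch of that reset-and-recycle contract (all names below are illustrative stand-ins, not adapter code):

#include <cassert>
#include <stack>

// Stand-in for a handle whose native resources outlive individual uses.
struct PooledEvent {
  int RefCount = 1;
  bool IsRecorded = false;
  bool IsStarted = false;

  // Clear per-use state only; native resources are kept for the next use.
  void reset() {
    assert(RefCount == 0 && "resetting an event that is still referenced");
    IsRecorded = false;
    IsStarted = false;
  }
};

int main() {
  std::stack<PooledEvent *> Cache;

  auto *Ev = new PooledEvent(); // first use
  Ev->IsRecorded = true;
  Ev->RefCount = 0;             // last reference dropped
  Ev->reset();
  Cache.push(Ev);               // recycled instead of deleted

  PooledEvent *Reused = Cache.top(); // later: reused, no new allocation
  Cache.pop();
  Reused->RefCount = 1;
  assert(!Reused->IsRecorded);

  delete Reused; // real teardown happens once, at the very end
  return 0;
}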
@@ -141,22 +164,6 @@ ur_result_t ur_event_handle_t_::wait() {
   return Result;
 }
 
-ur_result_t ur_event_handle_t_::release() {
-  if (!backendHasOwnership())
-    return UR_RESULT_SUCCESS;
-
-  assert(Queue != nullptr);
-
-  UR_CHECK_ERROR(cuEventDestroy(EvEnd));
-
-  if (Queue->URFlags & UR_QUEUE_FLAG_PROFILING_ENABLE || isTimestampEvent()) {
-    UR_CHECK_ERROR(cuEventDestroy(EvQueued));
-    UR_CHECK_ERROR(cuEventDestroy(EvStart));
-  }
-
-  return UR_RESULT_SUCCESS;
-}
-
 UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent,
                                                    ur_event_info_t propName,
                                                    size_t propValueSize,
@@ -254,16 +261,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) {
   // decrement ref count. If it is 0, delete the event.
   if (hEvent->decrementReferenceCount() == 0) {
     std::unique_ptr<ur_event_handle_t_> event_ptr{hEvent};
-    ur_result_t Result = UR_RESULT_ERROR_INVALID_EVENT;
     try {
       ScopedContext Active(hEvent->getContext());
-      Result = hEvent->release();
-    } catch (...) {
-      Result = UR_RESULT_ERROR_OUT_OF_RESOURCES;
+      if (!hEvent->backendHasOwnership()) {
+        return UR_RESULT_SUCCESS;
+      } else {
+        auto Queue = event_ptr->getQueue();
+        auto Context = event_ptr->getContext();
+
+        event_ptr->reset();
+        if (Queue) {
+          Queue->cache_event(event_ptr.release());
+          urQueueRelease(Queue);
+        }
+        urContextRelease(Context);
+      }
+    } catch (ur_result_t Err) {
+      return Err;
     }
-    return Result;
   }
 
   return UR_RESULT_SUCCESS;
 }

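Two things change in the release path above: events the backend does not own return early instead of going through teardown, and failures propagate as thrown ur_result_t values that are translated into a return code once, at the API boundary, rather than being collapsed to UR_RESULT_ERROR_OUT_OF_RESOURCES by a catch-all. A compact standalone sketch of that throw-and-translate convention (the enum and function names here are simplified stand-ins):

#include <cstdio>

enum ur_result_t { UR_RESULT_SUCCESS = 0, UR_RESULT_ERROR_OUT_OF_RESOURCES = 1 };

// A helper deep in the call chain throws the concrete error code...
void cacheEventOrThrow(bool InjectFailure) {
  if (InjectFailure)
    throw UR_RESULT_ERROR_OUT_OF_RESOURCES;
}

// ...and the entry point translates it into a returned status exactly once.
ur_result_t releaseSketch(bool InjectFailure) {
  try {
    cacheEventOrThrow(InjectFailure);
  } catch (ur_result_t Err) {
    return Err; // the original code survives; nothing is collapsed
  }
  return UR_RESULT_SUCCESS;
}

int main() {
  std::printf("ok path:   %d\n", releaseSketch(false));
  std::printf("fail path: %d\n", releaseSketch(true));
  return 0;
}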
19 changes: 18 additions & 1 deletion source/adapters/cuda/event.hpp
@@ -90,6 +90,21 @@ struct ur_event_handle_t_ {
     const bool RequiresTimings =
         Queue->URFlags & UR_QUEUE_FLAG_PROFILING_ENABLE ||
         Type == UR_COMMAND_TIMESTAMP_RECORDING_EXP;
+    if (Queue->has_cached_events()) {
+      auto retEvent = Queue->get_cached_event();
+
+      retEvent->Stream = Stream;
+      retEvent->StreamToken = StreamToken;
+      retEvent->CommandType = Type;
+      retEvent->Queue = Queue;
+      retEvent->Context = Queue->Context;
+      retEvent->RefCount = 1;
+
+      urQueueRetain(retEvent->Queue);
+      urContextRetain(retEvent->Context);
+
+      return retEvent;
+    }
     native_type EvEnd = nullptr, EvQueued = nullptr, EvStart = nullptr;
     UR_CHECK_ERROR(cuEventCreate(
         &EvEnd, RequiresTimings ? CU_EVENT_DEFAULT : CU_EVENT_DISABLE_TIMING));
@@ -107,7 +122,9 @@ struct ur_event_handle_t_ {
     return new ur_event_handle_t_(context, eventNative);
   }
 
-  ur_result_t release();
+  // Resets attributes of an event.
+  // Throws an error if its RefCount is not 0.
+  void reset();
 
   ~ur_event_handle_t_();
 
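The cache fast path in makeNative above re-links a recycled event by retaining its queue and context. This exactly balances the urQueueRelease/urContextRelease performed when urEventRelease parked the event in the cache, so parent reference counts are unchanged by any number of reuse cycles. A toy model of that invariant (a sketch, not the adapter's handle types):

#include <cassert>

struct Handle { int RefCount = 1; };

void retain(Handle &H) { ++H.RefCount; }
void release(Handle &H) { --H.RefCount; }

int main() {
  Handle Queue, Context;

  for (int Cycle = 0; Cycle < 3; ++Cycle) {
    // Reusing a cached event retains both parents (makeNative fast path)...
    retain(Queue);
    retain(Context);

    // ...and returning it to the cache releases them (urEventRelease).
    release(Queue);
    release(Context);
  }

  // Net effect over any number of reuse cycles is zero.
  assert(Queue.RefCount == 1 && Context.RefCount == 1);
  return 0;
}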
11 changes: 11 additions & 0 deletions source/adapters/cuda/queue.cpp
@@ -32,6 +32,17 @@ void ur_queue_handle_t_::transferStreamWaitForBarrierIfNeeded(
   }
 }
 
+ur_queue_handle_t_::~ur_queue_handle_t_() {
+  urContextRelease(Context);
+  urDeviceRelease(Device);
+
+  std::lock_guard<std::mutex> CacheGuard(CacheMutex);
+  while (!CachedEvents.empty()) {
+    std::unique_ptr<ur_event_handle_t_> Ev{CachedEvents.top()};
+    CachedEvents.pop();
+  }
+}
+
 CUstream ur_queue_handle_t_::getNextComputeStream(uint32_t *StreamToken) {
   uint32_t StreamI;
   uint32_t Token;
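The destructor above drains the cache with a small ownership idiom: each raw pointer popped from the stack is adopted by a unique_ptr, so the event's destructor (and with it the native-event cleanup) runs at the end of every loop iteration. The same drain loop in isolation (a sketch with a stand-in Event type):

#include <memory>
#include <stack>

struct Event {}; // stands in for ur_event_handle_t_, which owns native events

void drainCache(std::stack<Event *> &CachedEvents) {
  while (!CachedEvents.empty()) {
    std::unique_ptr<Event> Ev{CachedEvents.top()}; // adopt ownership
    CachedEvents.pop();
  } // Ev goes out of scope on each iteration, deleting the event
}

int main() {
  std::stack<Event *> Cache;
  Cache.push(new Event());
  Cache.push(new Event());
  drainCache(Cache); // cache is empty and both events are freed
  return 0;
}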
30 changes: 26 additions & 4 deletions source/adapters/cuda/queue.hpp
@@ -13,6 +13,7 @@
 
 #include <algorithm>
 #include <cuda.h>
+#include <stack>
 #include <vector>
 
 using ur_stream_guard_ = std::unique_lock<std::mutex>;
@@ -35,6 +36,9 @@ struct ur_queue_handle_t_ {
   // keep track of which streams have applied barrier
   std::vector<bool> ComputeAppliedBarrier;
   std::vector<bool> TransferAppliedBarrier;
+  // CachedEvents assumes ownership of events.
+  // Events on the stack are destructed when the queue is destructed as well.
+  std::stack<ur_event_handle_t> CachedEvents;
   ur_context_handle_t_ *Context;
   ur_device_handle_t_ *Device;
   CUevent BarrierEvent = nullptr;
@@ -57,6 +61,8 @@ struct ur_queue_handle_t_ {
   std::mutex ComputeStreamMutex;
   std::mutex TransferStreamMutex;
   std::mutex BarrierMutex;
+  // The event cache might be accessed in multiple threads.
+  std::mutex CacheMutex;
   bool HasOwnership;
 
   ur_queue_handle_t_(std::vector<CUstream> &&ComputeStreams,

Review comment on CacheMutex (Contributor): I might be inclined to have this close in code to the CachedEvents.

@@ -77,10 +83,7 @@ struct ur_queue_handle_t_ {
     urDeviceRetain(Device);
   }
 
-  ~ur_queue_handle_t_() {
-    urContextRelease(Context);
-    urDeviceRelease(Device);
-  }
+  ~ur_queue_handle_t_();
 
   void computeStreamWaitForBarrierIfNeeded(CUstream Strean, uint32_t StreamI);
   void transferStreamWaitForBarrierIfNeeded(CUstream Stream, uint32_t StreamI);
@@ -245,4 +248,23 @@ struct ur_queue_handle_t_ {
   uint32_t getNextEventID() noexcept { return ++EventCount; }
 
   bool backendHasOwnership() const noexcept { return HasOwnership; }
+
+  bool has_cached_events() {
+    std::lock_guard<std::mutex> CacheGuard(CacheMutex);
+    return !CachedEvents.empty();
+  }
+
+  void cache_event(ur_event_handle_t Event) {
+    std::lock_guard<std::mutex> CacheGuard(CacheMutex);
+    CachedEvents.push(Event);
+  }
+
+  // Returns and removes an event from the CachedEvents stack.
+  ur_event_handle_t get_cached_event() {
+    std::lock_guard<std::mutex> CacheGuard(CacheMutex);
+    assert(!CachedEvents.empty());
+    auto RetEv = CachedEvents.top();
+    CachedEvents.pop();
+    return RetEv;
+  }
 };

Review comment on the assert (Author, @MartinWehking, May 15, 2024): @konradkusiak97 spotted this. Need to change this to detail::ur::assertion.
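These three accessors are the cache's entire thread-safety surface: every touch of CachedEvents happens under CacheMutex. The HIP adapter below mirrors the scheme one-for-one, with hipEventDestroy in place of cuEventDestroy. A standalone sketch of the same mutex-guarded stack (int* stands in for ur_event_handle_t):

#include <cassert>
#include <mutex>
#include <stack>

// Minimal stand-in for the queue's event cache: a stack of raw handles
// guarded by a single mutex, matching the accessors above in shape.
class EventCache {
  std::stack<int *> CachedEvents;
  std::mutex CacheMutex;

public:
  bool has_cached_events() {
    std::lock_guard<std::mutex> CacheGuard(CacheMutex);
    return !CachedEvents.empty();
  }

  void cache_event(int *Event) {
    std::lock_guard<std::mutex> CacheGuard(CacheMutex);
    CachedEvents.push(Event);
  }

  int *get_cached_event() {
    std::lock_guard<std::mutex> CacheGuard(CacheMutex);
    assert(!CachedEvents.empty());
    int *RetEv = CachedEvents.top();
    CachedEvents.pop();
    return RetEv;
  }
};

int main() {
  EventCache Cache;
  int Payload = 42;
  Cache.cache_event(&Payload);
  if (Cache.has_cached_events()) {
    int *Ev = Cache.get_cached_event(); // the same object comes back
    assert(Ev == &Payload);
  }
  return 0;
}

One observation on the design: the has_cached_events()/get_cached_event() pair used by makeNative is not atomic as a whole, so if events on the same queue are released and created from multiple threads, another consumer could empty the cache between the two calls; only the assert inside get_cached_event guards that window.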
63 changes: 40 additions & 23 deletions source/adapters/hip/event.cpp
@@ -47,11 +47,34 @@ ur_event_handle_t_::ur_event_handle_t_(ur_context_handle_t Context,
   urContextRetain(Context);
 }
 
+void ur_event_handle_t_::reset() {
+  detail::ur::assertion(
+      RefCount == 0, "Attempting to reset an event that is still referenced");
+
+  HasBeenWaitedOn = false;
+  IsRecorded = false;
+  IsStarted = false;
+  Queue = nullptr;
+  Context = nullptr;
+}
+
 ur_event_handle_t_::~ur_event_handle_t_() {
-  if (Queue != nullptr) {
+  if (HasOwnership) {
+    if (EvEnd)
+      UR_CHECK_ERROR(hipEventDestroy(EvEnd));
+
+    if (EvQueued)
+      UR_CHECK_ERROR(hipEventDestroy(EvQueued));
+
+    if (EvStart)
+      UR_CHECK_ERROR(hipEventDestroy(EvStart));
+  }
+  if (Queue) {
     urQueueRelease(Queue);
   }
-  urContextRelease(Context);
+  if (Context) {
+    urContextRelease(Context);
+  }
 }
 
 ur_result_t ur_event_handle_t_::start() {
@@ -171,21 +194,6 @@ ur_result_t ur_event_handle_t_::wait() {
   return Result;
 }
 
-ur_result_t ur_event_handle_t_::release() {
-  if (!backendHasOwnership())
-    return UR_RESULT_SUCCESS;
-
-  assert(Queue != nullptr);
-  UR_CHECK_ERROR(hipEventDestroy(EvEnd));
-
-  if (Queue->URFlags & UR_QUEUE_FLAG_PROFILING_ENABLE || isTimestampEvent()) {
-    UR_CHECK_ERROR(hipEventDestroy(EvQueued));
-    UR_CHECK_ERROR(hipEventDestroy(EvStart));
-  }
-
-  return UR_RESULT_SUCCESS;
-}
-
 UR_APIEXPORT ur_result_t UR_APICALL
 urEventWait(uint32_t numEvents, const ur_event_handle_t *phEventWaitList) {
   UR_ASSERT(numEvents > 0, UR_RESULT_ERROR_INVALID_VALUE);
@@ -291,15 +299,24 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventRelease(ur_event_handle_t hEvent) {
   // decrement ref count. If it is 0, delete the event.
   if (hEvent->decrementReferenceCount() == 0) {
     std::unique_ptr<ur_event_handle_t_> event_ptr{hEvent};
-    ur_result_t Result = UR_RESULT_ERROR_INVALID_EVENT;
     try {
-      Result = hEvent->release();
-    } catch (...) {
-      Result = UR_RESULT_ERROR_OUT_OF_RESOURCES;
+      if (!hEvent->backendHasOwnership()) {
+        return UR_RESULT_SUCCESS;
+      } else {
+        auto Queue = event_ptr->getQueue();
+        auto Context = event_ptr->getContext();
+
+        event_ptr->reset();
+        if (Queue) {
+          Queue->cache_event(event_ptr.release());
+          urQueueRelease(Queue);
+        }
+        urContextRelease(Context);
+      }
+    } catch (ur_result_t Err) {
+      return Err;
     }
-    return Result;
   }
 
   return UR_RESULT_SUCCESS;
 }

19 changes: 18 additions & 1 deletion source/adapters/hip/event.hpp
@@ -82,6 +82,21 @@ struct ur_event_handle_t_ {
   static ur_event_handle_t
   makeNative(ur_command_t Type, ur_queue_handle_t Queue, hipStream_t Stream,
              uint32_t StreamToken = std::numeric_limits<uint32_t>::max()) {
+    if (Queue->has_cached_events()) {
+      auto retEvent = Queue->get_cached_event();
+
+      retEvent->Stream = Stream;
+      retEvent->StreamToken = StreamToken;
+      retEvent->CommandType = Type;
+      retEvent->Queue = Queue;
+      retEvent->Context = Queue->Context;
+      retEvent->RefCount = 1;
+
+      urQueueRetain(retEvent->Queue);
+      urContextRetain(retEvent->Context);
+
+      return retEvent;
+    }
     return new ur_event_handle_t_(Type, Queue->getContext(), Queue, Stream,
                                   StreamToken);
   }
@@ -91,7 +106,9 @@ struct ur_event_handle_t_ {
     return new ur_event_handle_t_(context, eventNative);
   }
 
-  ur_result_t release();
+  // Resets attributes of an event.
+  // Throws an error if its RefCount is not 0.
+  void reset();
 
   ~ur_event_handle_t_();
 
11 changes: 11 additions & 0 deletions source/adapters/hip/queue.cpp
@@ -28,6 +28,17 @@ void ur_queue_handle_t_::transferStreamWaitForBarrierIfNeeded(
   }
 }
 
+ur_queue_handle_t_::~ur_queue_handle_t_() {
+  urContextRelease(Context);
+  urDeviceRelease(Device);
+
+  std::lock_guard<std::mutex> CacheGuard(CacheMutex);
+  while (!CachedEvents.empty()) {
+    std::unique_ptr<ur_event_handle_t_> Ev{CachedEvents.top()};
+    CachedEvents.pop();
+  }
+}
+
 hipStream_t ur_queue_handle_t_::getNextComputeStream(uint32_t *StreamToken) {
   uint32_t Stream_i;
   uint32_t Token;