Skip to content

Commit 1c4137c

Browse files
authored
[SYCL][CUDA] Fix usage of multiple backends in the same program (#1252)
Implementation of piEventSetCallback with tests GlueEvent uses now the correct plugins The SYCL RT code for GlueEvent calls now the right plugin to create the event that triggers the dependency chain. Renamed variables to clarify the source code and avoid confusions between Context and Plugin Signed-off-by: Ruyman Reyes <ruyman@codeplay.com> Signed-off-by: Stuart Adams <stuart.adams@codeplay.com> Signed-off-by: Steffen Larsen <steffen.larsen@codeplay.com>
1 parent 3ee80a5 commit 1c4137c

File tree

7 files changed

+237
-71
lines changed

7 files changed

+237
-71
lines changed

sycl/plugins/cuda/pi_cuda.cpp

Lines changed: 49 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,8 @@ pi_result _pi_event::start() {
149149
}
150150

151151
isStarted_ = true;
152+
// let observers know that the event is "submitted"
153+
trigger_callback(get_execution_status());
152154
return result;
153155
}
154156

@@ -195,6 +197,22 @@ pi_result _pi_event::record() {
195197

196198
try {
197199
result = PI_CHECK_ERROR(cuEventRecord(evEnd_, cuStream));
200+
201+
result = cuda_piEventRetain(this);
202+
try {
203+
result = PI_CHECK_ERROR(cuLaunchHostFunc(
204+
cuStream,
205+
[](void *userData) {
206+
pi_event event = reinterpret_cast<pi_event>(userData);
207+
event->set_event_complete();
208+
cuda_piEventRelease(event);
209+
},
210+
this));
211+
} catch (...) {
212+
// If host function fails to enqueue we must release the event here
213+
result = cuda_piEventRelease(this);
214+
throw;
215+
}
198216
} catch (pi_result error) {
199217
result = error;
200218
}
@@ -215,6 +233,7 @@ pi_result _pi_event::wait() {
215233
if (is_native_event()) {
216234
try {
217235
retErr = PI_CHECK_ERROR(cuEventSynchronize(evEnd_));
236+
isCompleted_ = true;
218237
} catch (pi_result error) {
219238
retErr = error;
220239
}
@@ -226,30 +245,12 @@ pi_result _pi_event::wait() {
226245
retErr = PI_SUCCESS;
227246
}
228247

229-
return retErr;
230-
}
231-
232-
pi_event_status _pi_event::get_execution_status() const noexcept {
248+
auto is_success = retErr == PI_SUCCESS;
249+
auto status = is_success ? get_execution_status() : pi_int32(retErr);
233250

234-
if (!is_recorded()) {
235-
return PI_EVENT_SUBMITTED;
236-
}
237-
238-
if (is_native_event()) {
239-
// native event status
240-
241-
auto status = cuEventQuery(get());
242-
if (status == CUDA_ERROR_NOT_READY) {
243-
return PI_EVENT_RUNNING;
244-
} else if (status != CUDA_SUCCESS) {
245-
cl::sycl::detail::pi::die("Invalid CUDA event status");
246-
}
247-
return PI_EVENT_COMPLETE;
248-
} else {
249-
// user event status
251+
trigger_callback(status);
250252

251-
return is_completed() ? PI_EVENT_COMPLETE : PI_EVENT_RUNNING;
252-
}
253+
return retErr;
253254
}
254255

255256
// iterates over the event wait list, returns correct pi_result error codes.
@@ -2530,24 +2531,21 @@ pi_result cuda_piEventGetInfo(pi_event event, pi_event_info param_name,
25302531

25312532
switch (param_name) {
25322533
case PI_EVENT_INFO_COMMAND_QUEUE:
2533-
return getInfo<pi_queue>(param_value_size, param_value,
2534-
param_value_size_ret, event->get_queue());
2534+
return getInfo(param_value_size, param_value, param_value_size_ret,
2535+
event->get_queue());
25352536
case PI_EVENT_INFO_COMMAND_TYPE:
2536-
return getInfo<pi_command_type>(param_value_size, param_value,
2537-
param_value_size_ret,
2538-
event->get_command_type());
2537+
return getInfo(param_value_size, param_value, param_value_size_ret,
2538+
event->get_command_type());
25392539
case PI_EVENT_INFO_REFERENCE_COUNT:
2540-
return getInfo<pi_uint32>(param_value_size, param_value,
2541-
param_value_size_ret,
2542-
event->get_reference_count());
2540+
return getInfo(param_value_size, param_value, param_value_size_ret,
2541+
event->get_reference_count());
25432542
case PI_EVENT_INFO_COMMAND_EXECUTION_STATUS: {
2544-
return getInfo<pi_event_status>(param_value_size, param_value,
2545-
param_value_size_ret,
2546-
event->get_execution_status());
2543+
return getInfo(param_value_size, param_value, param_value_size_ret,
2544+
static_cast<pi_event_status>(event->get_execution_status()));
25472545
}
25482546
case PI_EVENT_INFO_CONTEXT:
2549-
return getInfo<pi_context>(param_value_size, param_value,
2550-
param_value_size_ret, event->get_context());
2547+
return getInfo(param_value_size, param_value, param_value_size_ret,
2548+
event->get_context());
25512549
default:
25522550
PI_HANDLE_UNKNOWN_PARAM_NAME(param_name);
25532551
}
@@ -2582,13 +2580,21 @@ pi_result cuda_piEventGetProfilingInfo(
25822580
return {};
25832581
}
25842582

2585-
pi_result cuda_piEventSetCallback(
2586-
pi_event event, pi_int32 command_exec_callback_type,
2587-
void (*pfn_notify)(pi_event event, pi_int32 event_command_status,
2588-
void *user_data),
2589-
void *user_data) {
2590-
cl::sycl::detail::pi::die("cuda_piEventSetCallback not implemented");
2591-
return {};
2583+
pi_result cuda_piEventSetCallback(pi_event event,
2584+
pi_int32 command_exec_callback_type,
2585+
pfn_notify notify, void *user_data) {
2586+
2587+
assert(event);
2588+
assert(notify);
2589+
assert(command_exec_callback_type == PI_EVENT_SUBMITTED ||
2590+
command_exec_callback_type == PI_EVENT_RUNNING ||
2591+
command_exec_callback_type == PI_EVENT_COMPLETE);
2592+
event_callback callback(pi_event_status(command_exec_callback_type), notify,
2593+
user_data);
2594+
2595+
event->set_event_callback(callback);
2596+
2597+
return PI_SUCCESS;
25922598
}
25932599

25942600
pi_result cuda_piEventSetStatus(pi_event event, pi_int32 execution_status) {
@@ -2601,7 +2607,7 @@ pi_result cuda_piEventSetStatus(pi_event event, pi_int32 execution_status) {
26012607
}
26022608

26032609
if (execution_status == PI_EVENT_COMPLETE) {
2604-
return event->set_user_event_complete();
2610+
return event->set_event_complete();
26052611
} else if (execution_status < 0) {
26062612
// TODO: A negative integer value causes all enqueued commands that wait
26072613
// on this user event to be terminated.

sycl/plugins/cuda/pi_cuda.hpp

Lines changed: 87 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,39 @@ struct _pi_queue {
235235
pi_uint32 get_reference_count() const noexcept { return refCount_; }
236236
};
237237

238+
typedef void (*pfn_notify)(pi_event event, pi_int32 eventCommandStatus,
239+
void *userData);
240+
241+
class event_callback {
242+
public:
243+
void trigger_callback(pi_event event, pi_int32 currentEventStatus) const {
244+
245+
auto validParameters = callback_ && event;
246+
247+
// As a pi_event_status value approaches 0, it gets closer to completion.
248+
// If the calling pi_event's status is less than or equal to the event
249+
// status the user is interested in, invoke the callback anyway. The event
250+
// will have passed through that state anyway.
251+
auto validStatus = currentEventStatus <= observedEventStatus_;
252+
253+
if (validParameters && validStatus) {
254+
255+
callback_(event, currentEventStatus, userData_);
256+
}
257+
}
258+
259+
event_callback(pi_event_status status, pfn_notify callback, void *userData)
260+
: observedEventStatus_{status}, callback_{callback}, userData_{userData} {
261+
}
262+
263+
pi_event_status get_status() const noexcept { return observedEventStatus_; }
264+
265+
private:
266+
pi_event_status observedEventStatus_;
267+
pfn_notify callback_;
268+
void *userData_;
269+
};
270+
238271
class _pi_event {
239272
public:
240273
using native_type = CUevent;
@@ -247,18 +280,39 @@ class _pi_event {
247280

248281
native_type get() const noexcept { return evEnd_; };
249282

250-
pi_result set_user_event_complete() noexcept {
283+
pi_result set_event_complete() noexcept {
251284

252285
if (isCompleted_) {
253286
return PI_INVALID_OPERATION;
254287
}
255288

256-
if (is_user_event()) {
257-
isRecorded_ = true;
258-
isCompleted_ = true;
259-
return PI_SUCCESS;
289+
isRecorded_ = true;
290+
isCompleted_ = true;
291+
292+
trigger_callback(get_execution_status());
293+
294+
return PI_SUCCESS;
295+
}
296+
297+
void trigger_callback(pi_int32 status) {
298+
299+
std::vector<event_callback> callbacks;
300+
301+
// Here we move all callbacks into local variable before we call them.
302+
// This is a defensive maneuver; if any of the callbacks attempt to
303+
// add additional callbacks, we will end up in a bad spot. Our mutex
304+
// will be locked twice and the vector will be modified as it is being
305+
// iterated over! By moving everything locally, we can call all of these
306+
// callbacks and let them modify the original vector without much worry.
307+
308+
{
309+
std::lock_guard<std::mutex> lock(mutex_);
310+
event_callbacks_.swap(callbacks);
311+
}
312+
313+
for (auto &event_callback : callbacks) {
314+
event_callback.trigger_callback(this, status);
260315
}
261-
return PI_INVALID_EVENT;
262316
}
263317

264318
pi_queue get_queue() const noexcept { return queue_; }
@@ -273,7 +327,27 @@ class _pi_event {
273327

274328
bool is_started() const noexcept { return isStarted_; }
275329

276-
pi_event_status get_execution_status() const noexcept;
330+
pi_int32 get_execution_status() const noexcept {
331+
332+
if (!is_recorded()) {
333+
return PI_EVENT_SUBMITTED;
334+
}
335+
336+
if (!is_completed()) {
337+
return PI_EVENT_RUNNING;
338+
}
339+
return PI_EVENT_COMPLETE;
340+
}
341+
342+
void set_event_callback(const event_callback &callback) {
343+
auto current_status = get_execution_status();
344+
if (current_status <= callback.get_status()) {
345+
callback.trigger_callback(this, current_status);
346+
} else {
347+
std::lock_guard<std::mutex> lock(mutex_);
348+
event_callbacks_.emplace_back(callback);
349+
}
350+
}
277351

278352
pi_context get_context() const noexcept { return context_; };
279353

@@ -343,6 +417,12 @@ class _pi_event {
343417
pi_context context_; // pi_context associated with the event. If this is a
344418
// native event, this will be the same context associated
345419
// with the queue_ member.
420+
421+
std::mutex mutex_; // Protect access to event_callbacks_. TODO: There might be
422+
// a lock-free data structure we can use here.
423+
std::vector<event_callback>
424+
event_callbacks_; // Callbacks that can be triggered when an event's state
425+
// changes.
346426
};
347427

348428
struct _pi_program {

sycl/source/detail/scheduler/commands.cpp

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -161,45 +161,50 @@ void EventCompletionClbk(RT::PiEvent, pi_int32, void *data) {
161161
EventImplPtr *Event = (reinterpret_cast<EventImplPtr *>(data));
162162
RT::PiEvent &EventHandle = (*Event)->getHandleRef();
163163
const detail::plugin &Plugin = (*Event)->getPlugin();
164-
Plugin.call<PiApiKind::piEventSetStatus>(EventHandle, CL_COMPLETE);
164+
Plugin.call<PiApiKind::piEventSetStatus>(EventHandle, PI_EVENT_COMPLETE);
165165
delete (Event);
166166
}
167167

168168
// Method prepares PI event's from list sycl::event's
169169
std::vector<EventImplPtr> Command::prepareEvents(ContextImplPtr Context) {
170170
std::vector<EventImplPtr> Result;
171171
std::vector<EventImplPtr> GlueEvents;
172-
for (EventImplPtr &Event : MDepsEvents) {
172+
for (EventImplPtr &DepEvent : MDepsEvents) {
173173
// Async work is not supported for host device.
174-
if (Event->is_host()) {
175-
Event->waitInternal();
174+
if (DepEvent->is_host()) {
175+
DepEvent->waitInternal();
176176
continue;
177177
}
178178
// The event handle can be null in case of, for example, alloca command,
179179
// which is currently synchrounious, so don't generate OpenCL event.
180-
if (Event->getHandleRef() == nullptr) {
180+
if (DepEvent->getHandleRef() == nullptr) {
181181
continue;
182182
}
183-
ContextImplPtr EventContext = Event->getContextImpl();
184-
const detail::plugin &Plugin = Event->getPlugin();
185-
// If contexts don't match - connect them using user event
186-
if (EventContext != Context && !Context->is_host()) {
183+
ContextImplPtr DepEventContext = DepEvent->getContextImpl();
187184

185+
// If contexts don't match - connect them using user event
186+
if (DepEventContext != Context && !Context->is_host()) {
188187
EventImplPtr GlueEvent(new detail::event_impl());
189188
GlueEvent->setContextImpl(Context);
189+
EventImplPtr *GlueEventCopy =
190+
new EventImplPtr(GlueEvent); // To increase the reference count by 1.
191+
190192
RT::PiEvent &GlueEventHandle = GlueEvent->getHandleRef();
193+
auto Plugin = Context->getPlugin();
194+
auto DepPlugin = DepEventContext->getPlugin();
195+
// Add an event on the current context that
196+
// is triggered when the DepEvent is complete
191197
Plugin.call<PiApiKind::piEventCreate>(Context->getHandleRef(),
192198
&GlueEventHandle);
193-
EventImplPtr *GlueEventCopy =
194-
new EventImplPtr(GlueEvent); // To increase the reference count by 1.
195-
Plugin.call<PiApiKind::piEventSetCallback>(
196-
Event->getHandleRef(), CL_COMPLETE, EventCompletionClbk,
199+
200+
DepPlugin.call<PiApiKind::piEventSetCallback>(
201+
DepEvent->getHandleRef(), PI_EVENT_COMPLETE, EventCompletionClbk,
197202
/*void *data=*/(GlueEventCopy));
198203
GlueEvents.push_back(GlueEvent);
199204
Result.push_back(std::move(GlueEvent));
200205
continue;
201206
}
202-
Result.push_back(Event);
207+
Result.push_back(DepEvent);
203208
}
204209
MDepsEvents.insert(MDepsEvents.end(), GlueEvents.begin(), GlueEvents.end());
205210
return Result;

sycl/test/basic_tests/buffer/buffer_dev_to_dev.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
// RUN: %GPU_RUN_PLACEHOLDER %t.out
55
// RUN: %ACC_RUN_PLACEHOLDER %t.out
66

7-
// TODO: pi_die: cuda_piEventSetCallback not implemented
8-
// XFAIL: cuda
9-
107
//==---------- buffer_dev_to_dev.cpp - SYCL buffer basic test --------------==//
118
//
129
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.

sycl/test/scheduler/DataMovement.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -I %sycl_source_dir %s -o %t.out
22
// RUN: %t.out
33
//
4-
// XFAIL: cuda
54
//==-------------------------- DataMovement.cpp ----------------------------==//
65
//
76
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.

sycl/test/scheduler/MultipleDevices.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple -I %sycl_source_dir %s -o %t.out
22
// RUN: %t.out
33

4-
// TODO: pi_die: cuda_piEventSetCallback not implemented
5-
// XFAIL: cuda
6-
74
//===- MultipleDevices.cpp - Test checking multi-device execution --------===//
85
//
96
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.

0 commit comments

Comments
 (0)