Skip to content

Commit 4a86800

Browse files
chore(gpu): use active streams in int_radix_lut
1 parent 1513c3b commit 4a86800

File tree

7 files changed

+408
-409
lines changed

7 files changed

+408
-409
lines changed

backends/tfhe-cuda-backend/cuda/include/helper_multi_gpu.h

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,4 +183,135 @@ struct CudaStreams {
183183
}
184184
};
185185

186+
struct CudaStreamsWorkersWaitFirstBarrier {
187+
private:
188+
cudaEvent_t _event;
189+
CudaStreams _streams;
190+
191+
CudaStreamsWorkersWaitFirstBarrier(
192+
const CudaStreamsWorkersWaitFirstBarrier &) {
193+
} // Prevent copy-construction
194+
CudaStreamsWorkersWaitFirstBarrier &
195+
operator=(const CudaStreamsWorkersWaitFirstBarrier &) {
196+
return *this;
197+
} // Prevent assignment
198+
public:
199+
void create_on(const CudaStreams &streams) {
200+
_streams = streams;
201+
_event = cuda_create_event(streams.gpu_index(0));
202+
}
203+
204+
CudaStreamsWorkersWaitFirstBarrier() { _event = nullptr; };
205+
206+
void workers_wait_for_gpu_0() {
207+
GPU_ASSERT(
208+
_event != nullptr,
209+
"CudaStreamsWorkersWaitFirstBarrier: must call create_on before use");
210+
211+
if (_streams.count() > 1) {
212+
cuda_event_record(_event, _streams.stream(0), _streams.gpu_index(0));
213+
for (int j = 1; j < _streams.count(); j++) {
214+
cuda_stream_wait_event(_streams.stream(j), _event,
215+
_streams.gpu_index(j));
216+
}
217+
}
218+
}
219+
220+
void user_streams_wait_for_gpu_0(const CudaStreams &user_streams) {
221+
GPU_ASSERT(
222+
_event != nullptr,
223+
"CudaStreamsWorkersWaitFirstBarrier: must call create_on before use");
224+
225+
cuda_event_record(_event, user_streams.stream(0),
226+
user_streams.gpu_index(0));
227+
for (int j = 1; j < user_streams.count(); j++) {
228+
cuda_stream_wait_event(user_streams.stream(j), _event,
229+
user_streams.gpu_index(j));
230+
}
231+
}
232+
233+
void release() {
234+
cuda_event_destroy(_event, _streams.gpu_index(0));
235+
_event = nullptr;
236+
}
237+
238+
~CudaStreamsWorkersWaitFirstBarrier() {
239+
GPU_ASSERT(_event == nullptr, "CudaStreamsWorkersWaitFirstBarrier: must "
240+
"call release before destruction");
241+
}
242+
};
243+
244+
struct CudaStreamsFirstWaitsWorkersBarrier {
245+
private:
246+
std::vector<cudaEvent_t> _events;
247+
CudaStreams _lut_streams;
248+
249+
CudaStreamsFirstWaitsWorkersBarrier(
250+
const CudaStreamsFirstWaitsWorkersBarrier &) {
251+
} // Prevent copy-construction
252+
CudaStreamsFirstWaitsWorkersBarrier &
253+
operator=(const CudaStreamsFirstWaitsWorkersBarrier &) {
254+
return *this;
255+
} // Prevent assignment
256+
public:
257+
void create_on(const CudaStreams &streams) {
258+
GPU_ASSERT(streams.count() > 1, "CudaStreamsFirstWaitsWorkersBarrier: "
259+
"Attempted to create on single GPU");
260+
_lut_streams = streams;
261+
_events.resize(streams.count());
262+
for (int i = 0; i < streams.count(); i++) {
263+
_events[i] = cuda_create_event(streams.gpu_index(i));
264+
}
265+
}
266+
267+
CudaStreamsFirstWaitsWorkersBarrier(){};
268+
269+
void gpu_0_wait_for_user_streams(const CudaStreams &user_streams) {
270+
GPU_ASSERT(
271+
!_events.empty(),
272+
"CudaStreamsFirstWaitsWorkersBarrier: must call create_on before use");
273+
GPU_ASSERT(
274+
user_streams.count() <= _events.size(),
275+
"CudaStreamsFirstWaitsWorkersBarrier: trying to synchronize too many "
276+
"streams. "
277+
"The barrier was created on a LUT that had %lu active streams, while "
278+
"the user stream set has %u streams",
279+
_events.size(), user_streams.count());
280+
281+
if (user_streams.count() > 1) {
282+
// Worker GPUs record their events
283+
for (int j = 0; j < user_streams.count(); j++) {
284+
GPU_ASSERT(_lut_streams.gpu_index(j) == user_streams.gpu_index(j),
285+
"CudaStreamsFirstWaitsWorkersBarrier: The user stream "
286+
"set GPU[%d]=%u while the LUT stream set GPU[%d]=%u",
287+
j, user_streams.gpu_index(j), j, _lut_streams.gpu_index(j));
288+
289+
cuda_event_record(_events[j], user_streams.stream(j),
290+
user_streams.gpu_index(j));
291+
}
292+
293+
// GPU 0 waits for all workers
294+
for (int j = 0; j < user_streams.count(); j++) {
295+
cuda_stream_wait_event(user_streams.stream(0), _events[j],
296+
user_streams.gpu_index(0));
297+
}
298+
}
299+
}
300+
301+
void release() {
302+
for (int j = 0; j < _lut_streams.count(); j++) {
303+
cuda_event_destroy(_events[j], _lut_streams.gpu_index(j));
304+
}
305+
306+
_events.clear();
307+
}
308+
309+
~CudaStreamsFirstWaitsWorkersBarrier() {
310+
GPU_ASSERT(_events.empty(),
311+
"CudaStreamsFirstWaitsWorkersBarrier: must "
312+
"call release before destruction: events size = %lu",
313+
_events.size());
314+
}
315+
};
316+
186317
#endif

0 commit comments

Comments
 (0)