Skip to content

Commit

Permalink
add elapsed_time function for Event (PaddlePaddle#67605)
Browse files Browse the repository at this point in the history
* add elapsed_time function for Event

* fix example

* fix unitest
GuoxiaWang authored Aug 28, 2024
1 parent 2f4212d commit dddda15
Showing 3 changed files with 41 additions and 1 deletion.
26 changes: 26 additions & 0 deletions paddle/fluid/pybind/cuda_streams_py.cc
Original file line number Diff line number Diff line change
@@ -374,6 +374,32 @@ void BindCudaStream(py::module *m_ptr) {
>>> event = paddle.device.cuda.Event()
>>> is_done = event.query()
)DOC")
.def(
"elapsed_time",
[](phi::CudaEvent &self, phi::CudaEvent &end_event) {
return self.ElapsedTime(&end_event);
},
R"DOC(
Returns the time elapsed in milliseconds after the event was
recorded and before the end_event was recorded.
Returns: A int which indicates the elapsed time.
Examples:
.. code-block:: python
>>> # doctest: +REQUIRES(env:GPU)
>>> import paddle
>>> paddle.set_device('gpu')
>>> e1 = paddle.device.Event(enable_timing=True)
>>> e1.record()
>>> e2 = paddle.device.Event(enable_timing=True)
>>> e2.record()
>>> e1.elapsed_time(e2)
)DOC")
.def(
"synchronize",
14 changes: 14 additions & 0 deletions paddle/phi/api/profiler/event.h
Original file line number Diff line number Diff line change
@@ -196,6 +196,20 @@ class CudaEvent {
return false;
}

float ElapsedTime(CudaEvent *end_event) {
float milliseconds = 0;
#ifdef PADDLE_WITH_HIP
hipEventSynchronize(end_event->GetRawCudaEvent());
PADDLE_ENFORCE_GPU_SUCCESS(hipEventElapsedTime(
&milliseconds, event_, end_event->GetRawCudaEvent()));
#else
cudaEventSynchronize(end_event->GetRawCudaEvent());
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventElapsedTime(
&milliseconds, event_, end_event->GetRawCudaEvent()));
#endif
return milliseconds;
}

void Synchronize() {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipEventSynchronize(event_));
2 changes: 1 addition & 1 deletion python/paddle/device/__init__.py
Original file line number Diff line number Diff line change
@@ -604,7 +604,7 @@ def elapsed_time(self, end_event: Event) -> int:
>>> e1.elapsed_time(e2)
'''
return 0
return self.event_base.elapsed_time(end_event.event_base)

def synchronize(self) -> None:
'''

0 comments on commit dddda15

Please sign in to comment.