@@ -47,16 +47,16 @@ inline uint64_t GetTimeInNsec() {
4747}
4848
4949Event::Event (EventKind kind, std::string name, uint32_t thread_id,
50- DeviceContext* dev_ctx)
50+ const DeviceContext* dev_ctx)
5151 : kind_(kind), name_(name), thread_id_(thread_id), has_cuda_(false ) {
5252#ifdef PADDLE_WITH_CUDA
53- auto * cuda_dev_ctx = static_cast <const CUDADeviceContext*>(dev_ctx);
54- if (cuda_dev_ctx) {
53+ has_cuda_ = dev_ctx ? platform::is_gpu_place (dev_ctx->GetPlace ()) : false ;
54+ if (has_cuda_) {
55+ auto * cuda_dev_ctx = static_cast <const CUDADeviceContext*>(dev_ctx);
5556 PADDLE_ENFORCE (cudaGetDevice (&device_));
5657 PADDLE_ENFORCE (cudaEventCreate (&event_));
5758 auto stream = cuda_dev_ctx->stream ();
5859 PADDLE_ENFORCE (cudaEventRecord (event_, stream));
59- has_cuda_ = true ;
6060 }
6161#endif
6262 cpu_ns_ = GetTimeInNsec ();
@@ -114,19 +114,20 @@ inline EventList& GetEventList() {
114114 return *g_event_list;
115115}
116116
117- void Mark (const std::string& name, DeviceContext* dev_ctx) {
117+ void Mark (const std::string& name, const DeviceContext* dev_ctx) {
118118 GetEventList ().Record (EventKind::kMark , name, g_thread_id, dev_ctx);
119119}
120120
121- void PushEvent (const std::string& name, DeviceContext* dev_ctx) {
121+ void PushEvent (const std::string& name, const DeviceContext* dev_ctx) {
122122 GetEventList ().Record (EventKind::kPushRange , name, g_thread_id, dev_ctx);
123123}
124124
125- void PopEvent (const std::string& name, DeviceContext* dev_ctx) {
125+ void PopEvent (const std::string& name, const DeviceContext* dev_ctx) {
126126 GetEventList ().Record (EventKind::kPopRange , name, g_thread_id, dev_ctx);
127127}
128128
129- RecordEvent::RecordEvent (const std::string& name, DeviceContext* dev_ctx) {
129+ RecordEvent::RecordEvent (const std::string& name,
130+ const DeviceContext* dev_ctx) {
130131 if (g_state == ProfilerState::kDisabled ) return ;
131132 dev_ctx_ = dev_ctx;
132133 name_ = name;
@@ -155,6 +156,7 @@ void EnableProfiler(ProfilerState state) {
155156 DeviceContext* dev_ctx = new CUDADeviceContext (CUDAPlace (d));
156157 Mark (" _cuda_startup_" , dev_ctx);
157158 dev_ctx->Wait ();
159+ delete dev_ctx;
158160 });
159161 }
160162 }
@@ -163,21 +165,36 @@ void EnableProfiler(ProfilerState state) {
163165 Mark (" _start_profiler_" , nullptr );
164166}
165167
166- std::vector<std::vector<Event>> DisableProfiler () {
167- PADDLE_ENFORCE (g_state != ProfilerState::kDisabled ,
168- " Can't disable profiling, since it's not starting." );
169- // Mark the profiling stop.
170- Mark (" _stop_profiler_" , nullptr );
171- g_state = ProfilerState::kDisabled ;
172- std::vector<std::vector<Event>> result;
168+ void ResetProfiler () {
173169 std::lock_guard<std::mutex> guard (g_all_event_lists_mutex);
170+ for (auto it = g_all_event_lists.begin (); it != g_all_event_lists.end ();
171+ ++it) {
172+ (*it)->Clear ();
173+ }
174+ }
175+
176+ std::vector<std::vector<Event>> GetAllEvents () {
177+ std::lock_guard<std::mutex> guard (g_all_event_lists_mutex);
178+ std::vector<std::vector<Event>> result;
174179 for (auto it = g_all_event_lists.begin (); it != g_all_event_lists.end ();
175180 ++it) {
176181 result.emplace_back ((*it)->Reduce ());
177182 }
178183 return result;
179184}
180185
186+ void DisableProfiler (EventSortingKey sorted_key) {
187+ PADDLE_ENFORCE (g_state != ProfilerState::kDisabled ,
188+ " Can't disable profiling, since it's not starting." );
189+ // Mark the profiling stop.
190+ Mark (" _stop_profiler_" , nullptr );
191+ g_state = ProfilerState::kDisabled ;
192+
193+ std::vector<std::vector<Event>> all_events = GetAllEvents ();
194+ ParseEvents (all_events, sorted_key);
195+ ResetProfiler ();
196+ }
197+
181198void ParseEvents (std::vector<std::vector<Event>>& events,
182199 EventSortingKey sorted_by) {
183200 if (g_profiler_place == " " ) return ;
@@ -291,12 +308,12 @@ void ParseEvents(std::vector<std::vector<Event>>& events,
291308 }
292309
293310 // Print report
294- PrintProfilingReport (events_table, sorted_domain, max_name_width + 4 , 12 );
311+ PrintProfiler (events_table, sorted_domain, max_name_width + 4 , 12 );
295312}
296313
297- void PrintProfilingReport (std::vector<std::vector<EventItem>>& events_table,
298- std::string& sorted_domain, const size_t name_width,
299- const size_t data_width) {
314+ void PrintProfiler (std::vector<std::vector<EventItem>>& events_table,
315+ std::string& sorted_domain, const size_t name_width,
316+ const size_t data_width) {
300317 // Output header information
301318 std::cout << " \n ------------------------->"
302319 << " Profiling Report "
0 commit comments