@@ -13,40 +13,193 @@ See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
15
#include " paddle/fluid/platform/os_info.h"
16
+ #include < functional>
17
+ #include < mutex>
16
18
#include < sstream>
19
+ #include < thread>
20
+ #include < vector>
17
21
#if defined(__linux__)
18
22
#include < sys/syscall.h>
19
23
#include < sys/types.h>
20
24
#include < unistd.h>
21
25
#elif defined(_MSC_VER)
22
26
#include < processthreadsapi.h>
23
27
#endif
28
+ #include " paddle/fluid/platform/macros.h" // import DISABLE_COPY_AND_ASSIGN
24
29
25
30
namespace paddle {
26
31
namespace platform {
32
+ namespace internal {
27
33
28
- ThreadId::ThreadId () {
34
+ static uint64_t main_tid =
35
+ std::hash<std::thread::id>()(std::this_thread::get_id());
36
+
37
+ template <typename T>
38
+ class ThreadDataRegistry {
39
+ class ThreadDataHolder ;
40
+
41
+ public:
42
+ // Singleton
43
+ static ThreadDataRegistry& GetInstance () {
44
+ static ThreadDataRegistry instance;
45
+ return instance;
46
+ }
47
+
48
+ const T& GetCurrentThreadData () { return CurrentThreadData (); }
49
+
50
+ void SetCurrentThreadData (const T& val) {
51
+ std::lock_guard<std::mutex> lock (lock_);
52
+ CurrentThreadData () = val;
53
+ }
54
+
55
+ // Returns current snapshot of all threads. Make sure there is no thread
56
+ // create/destory when using it.
57
+ template <typename = std::enable_if_t <std::is_copy_constructible<T>::value>>
58
+ std::unordered_map<uint64_t , T> GetAllThreadDataByValue () {
59
+ std::unordered_map<uint64_t , T> data_copy;
60
+ std::lock_guard<std::mutex> lock (lock_);
61
+ data_copy.reserve (tid_map_.size ());
62
+ for (auto & kv : tid_map_) {
63
+ data_copy.emplace (kv.first , kv.second ->GetData ());
64
+ }
65
+ return std::move (data_copy);
66
+ }
67
+
68
+ void RegisterData (uint64_t tid, ThreadDataHolder* tls_obj) {
69
+ std::lock_guard<std::mutex> lock (lock_);
70
+ tid_map_[tid] = tls_obj;
71
+ }
72
+
73
+ void UnregisterData (uint64_t tid) {
74
+ if (tid == main_tid) {
75
+ return ;
76
+ }
77
+ std::lock_guard<std::mutex> lock (lock_);
78
+ tid_map_.erase (tid);
79
+ }
80
+
81
+ private:
82
+ class ThreadDataHolder {
83
+ public:
84
+ ThreadDataHolder () {
85
+ tid_ = std::hash<std::thread::id>()(std::this_thread::get_id ());
86
+ ThreadDataRegistry::GetInstance ().RegisterData (tid_, this );
87
+ }
88
+
89
+ ~ThreadDataHolder () {
90
+ ThreadDataRegistry::GetInstance ().UnregisterData (tid_);
91
+ }
92
+
93
+ T& GetData () { return data_; }
94
+
95
+ private:
96
+ uint64_t tid_;
97
+ T data_;
98
+ };
99
+
100
+ ThreadDataRegistry () = default ;
101
+
102
+ DISABLE_COPY_AND_ASSIGN (ThreadDataRegistry);
103
+
104
+ T& CurrentThreadData () {
105
+ static thread_local ThreadDataHolder thread_data;
106
+ return thread_data.GetData ();
107
+ }
108
+
109
+ std::mutex lock_;
110
+ std::unordered_map<uint64_t , ThreadDataHolder*> tid_map_; // not owned
111
+ };
112
+
113
+ class InternalThreadId {
114
+ public:
115
+ InternalThreadId ();
116
+
117
+ const ThreadId& GetTid () const { return id_; }
118
+
119
+ private:
120
+ ThreadId id_;
121
+ };
122
+
123
+ InternalThreadId::InternalThreadId () {
29
124
// C++ std tid
30
- std_tid_ = std::hash<std::thread::id>()(std::this_thread::get_id ());
125
+ id_. std_tid = std::hash<std::thread::id>()(std::this_thread::get_id ());
31
126
// system tid
32
127
#if defined(__linux__)
33
- sys_tid_ = syscall (SYS_gettid);
128
+ id_. sys_tid = static_cast < uint64_t >( syscall (SYS_gettid) );
34
129
#elif defined(_MSC_VER)
35
- sys_tid_ = GetCurrentThreadId ();
36
- #else // unsupported platforms
37
- sys_tid_ = 0 ;
130
+ id_. sys_tid = static_cast < uint64_t >(:: GetCurrentThreadId () );
131
+ #else // unsupported platforms, use std_tid
132
+ id_. sys_tid = id_. std_tid ;
38
133
#endif
39
134
// cupti tid
40
135
std::stringstream ss;
41
136
ss << std::this_thread::get_id ();
42
- cupti_tid_ = static_cast <uint32_t >(std::stoull (ss.str ()));
137
+ id_.cupti_tid = static_cast <uint32_t >(std::stoull (ss.str ()));
138
+ }
139
+
140
+ } // namespace internal
141
+
142
+ uint64_t GetCurrentThreadSysId () {
143
+ return internal::ThreadDataRegistry<internal::InternalThreadId>::GetInstance ()
144
+ .GetCurrentThreadData ()
145
+ .GetTid ()
146
+ .sys_tid ;
43
147
}
44
148
45
- ThreadIdRegistry::~ThreadIdRegistry () {
46
- std::lock_guard<std::mutex> lock (lock_);
47
- for (auto id_pair : id_map_) {
48
- delete id_pair.second ;
149
+ uint64_t GetCurrentThreadStdId () {
150
+ return internal::ThreadDataRegistry<internal::InternalThreadId>::GetInstance ()
151
+ .GetCurrentThreadData ()
152
+ .GetTid ()
153
+ .std_tid ;
154
+ }
155
+
156
+ ThreadId GetCurrentThreadId () {
157
+ return internal::ThreadDataRegistry<internal::InternalThreadId>::GetInstance ()
158
+ .GetCurrentThreadData ()
159
+ .GetTid ();
160
+ }
161
+
162
+ std::unordered_map<uint64_t , ThreadId> GetAllThreadIds () {
163
+ auto tids =
164
+ internal::ThreadDataRegistry<internal::InternalThreadId>::GetInstance ()
165
+ .GetAllThreadDataByValue ();
166
+ std::unordered_map<uint64_t , ThreadId> res;
167
+ for (const auto & kv : tids) {
168
+ res[kv.first ] = kv.second .GetTid ();
49
169
}
170
+ return res;
171
+ }
172
+
173
+ static constexpr const char * kDefaultThreadName = " unset" ;
174
+
175
+ std::string GetCurrentThreadName () {
176
+ const auto & thread_name =
177
+ internal::ThreadDataRegistry<std::string>::GetInstance ()
178
+ .GetCurrentThreadData ();
179
+ return thread_name.empty () ? kDefaultThreadName : thread_name;
180
+ }
181
+
182
+ std::unordered_map<uint64_t , std::string> GetAllThreadNames () {
183
+ return internal::ThreadDataRegistry<std::string>::GetInstance ()
184
+ .GetAllThreadDataByValue ();
185
+ }
186
+
187
+ bool SetCurrentThreadName (const std::string& name) {
188
+ auto & instance = internal::ThreadDataRegistry<std::string>::GetInstance ();
189
+ const auto & cur_name = instance.GetCurrentThreadData ();
190
+ if (!cur_name.empty () || cur_name == kDefaultThreadName ) {
191
+ return false ;
192
+ }
193
+ instance.SetCurrentThreadData (name);
194
+ return true ;
195
+ }
196
+
197
+ uint32_t GetProcessId () {
198
+ #if defined(_MSC_VER)
199
+ return static_cast <uint32_t >(GetCurrentProcessId ());
200
+ #else
201
+ return static_cast <uint32_t >(getpid ());
202
+ #endif
50
203
}
51
204
52
205
} // namespace platform
0 commit comments