Skip to content

Commit 0d8d1e0

Browse files
liutiexingliutiexing
andauthored
Os info (#38779)
* add align for WorkQueue * add spinlock * merge develop * merge * Add EventsWaiter * Revert "Add EventsWaiter" This reverts commit e206173. * os_info update * update * update * update * update * update * fix * update * update for windows * fix windows * update * update Co-authored-by: liutiexing <liutiexing@google.com>
1 parent b7bae93 commit 0d8d1e0

File tree

5 files changed

+234
-64
lines changed

5 files changed

+234
-64
lines changed

paddle/fluid/platform/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ ENDIF()
4747
cc_library(cpu_info SRCS cpu_info.cc DEPS ${CPU_INFO_DEPS})
4848
cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info)
4949
cc_library(os_info SRCS os_info.cc DEPS enforce)
50+
cc_test(os_info_test SRCS os_info_test.cc DEPS os_info)
5051

5152
IF(WITH_GPU)
5253
nv_library(cuda_graph_with_memory_pool SRCS cuda_graph_with_memory_pool.cc DEPS device_context allocator_facade cuda_graph)

paddle/fluid/platform/os_info.cc

Lines changed: 164 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,40 +13,193 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/platform/os_info.h"
16+
#include <functional>
17+
#include <mutex>
1618
#include <sstream>
19+
#include <thread>
20+
#include <vector>
1721
#if defined(__linux__)
1822
#include <sys/syscall.h>
1923
#include <sys/types.h>
2024
#include <unistd.h>
2125
#elif defined(_MSC_VER)
2226
#include <processthreadsapi.h>
2327
#endif
28+
#include "paddle/fluid/platform/macros.h" // import DISABLE_COPY_AND_ASSIGN
2429

2530
namespace paddle {
2631
namespace platform {
32+
namespace internal {
2733

28-
ThreadId::ThreadId() {
34+
static uint64_t main_tid =
35+
std::hash<std::thread::id>()(std::this_thread::get_id());
36+
37+
template <typename T>
38+
class ThreadDataRegistry {
39+
class ThreadDataHolder;
40+
41+
public:
42+
// Singleton
43+
static ThreadDataRegistry& GetInstance() {
44+
static ThreadDataRegistry instance;
45+
return instance;
46+
}
47+
48+
const T& GetCurrentThreadData() { return CurrentThreadData(); }
49+
50+
void SetCurrentThreadData(const T& val) {
51+
std::lock_guard<std::mutex> lock(lock_);
52+
CurrentThreadData() = val;
53+
}
54+
55+
// Returns current snapshot of all threads. Make sure there is no thread
56+
// create/destory when using it.
57+
template <typename = std::enable_if_t<std::is_copy_constructible<T>::value>>
58+
std::unordered_map<uint64_t, T> GetAllThreadDataByValue() {
59+
std::unordered_map<uint64_t, T> data_copy;
60+
std::lock_guard<std::mutex> lock(lock_);
61+
data_copy.reserve(tid_map_.size());
62+
for (auto& kv : tid_map_) {
63+
data_copy.emplace(kv.first, kv.second->GetData());
64+
}
65+
return std::move(data_copy);
66+
}
67+
68+
void RegisterData(uint64_t tid, ThreadDataHolder* tls_obj) {
69+
std::lock_guard<std::mutex> lock(lock_);
70+
tid_map_[tid] = tls_obj;
71+
}
72+
73+
void UnregisterData(uint64_t tid) {
74+
if (tid == main_tid) {
75+
return;
76+
}
77+
std::lock_guard<std::mutex> lock(lock_);
78+
tid_map_.erase(tid);
79+
}
80+
81+
private:
82+
class ThreadDataHolder {
83+
public:
84+
ThreadDataHolder() {
85+
tid_ = std::hash<std::thread::id>()(std::this_thread::get_id());
86+
ThreadDataRegistry::GetInstance().RegisterData(tid_, this);
87+
}
88+
89+
~ThreadDataHolder() {
90+
ThreadDataRegistry::GetInstance().UnregisterData(tid_);
91+
}
92+
93+
T& GetData() { return data_; }
94+
95+
private:
96+
uint64_t tid_;
97+
T data_;
98+
};
99+
100+
ThreadDataRegistry() = default;
101+
102+
DISABLE_COPY_AND_ASSIGN(ThreadDataRegistry);
103+
104+
T& CurrentThreadData() {
105+
static thread_local ThreadDataHolder thread_data;
106+
return thread_data.GetData();
107+
}
108+
109+
std::mutex lock_;
110+
std::unordered_map<uint64_t, ThreadDataHolder*> tid_map_; // not owned
111+
};
112+
113+
class InternalThreadId {
114+
public:
115+
InternalThreadId();
116+
117+
const ThreadId& GetTid() const { return id_; }
118+
119+
private:
120+
ThreadId id_;
121+
};
122+
123+
InternalThreadId::InternalThreadId() {
29124
// C++ std tid
30-
std_tid_ = std::hash<std::thread::id>()(std::this_thread::get_id());
125+
id_.std_tid = std::hash<std::thread::id>()(std::this_thread::get_id());
31126
// system tid
32127
#if defined(__linux__)
33-
sys_tid_ = syscall(SYS_gettid);
128+
id_.sys_tid = static_cast<uint64_t>(syscall(SYS_gettid));
34129
#elif defined(_MSC_VER)
35-
sys_tid_ = GetCurrentThreadId();
36-
#else // unsupported platforms
37-
sys_tid_ = 0;
130+
id_.sys_tid = static_cast<uint64_t>(::GetCurrentThreadId());
131+
#else // unsupported platforms, use std_tid
132+
id_.sys_tid = id_.std_tid;
38133
#endif
39134
// cupti tid
40135
std::stringstream ss;
41136
ss << std::this_thread::get_id();
42-
cupti_tid_ = static_cast<uint32_t>(std::stoull(ss.str()));
137+
id_.cupti_tid = static_cast<uint32_t>(std::stoull(ss.str()));
138+
}
139+
140+
} // namespace internal
141+
142+
uint64_t GetCurrentThreadSysId() {
143+
return internal::ThreadDataRegistry<internal::InternalThreadId>::GetInstance()
144+
.GetCurrentThreadData()
145+
.GetTid()
146+
.sys_tid;
43147
}
44148

45-
ThreadIdRegistry::~ThreadIdRegistry() {
46-
std::lock_guard<std::mutex> lock(lock_);
47-
for (auto id_pair : id_map_) {
48-
delete id_pair.second;
149+
uint64_t GetCurrentThreadStdId() {
150+
return internal::ThreadDataRegistry<internal::InternalThreadId>::GetInstance()
151+
.GetCurrentThreadData()
152+
.GetTid()
153+
.std_tid;
154+
}
155+
156+
ThreadId GetCurrentThreadId() {
157+
return internal::ThreadDataRegistry<internal::InternalThreadId>::GetInstance()
158+
.GetCurrentThreadData()
159+
.GetTid();
160+
}
161+
162+
std::unordered_map<uint64_t, ThreadId> GetAllThreadIds() {
163+
auto tids =
164+
internal::ThreadDataRegistry<internal::InternalThreadId>::GetInstance()
165+
.GetAllThreadDataByValue();
166+
std::unordered_map<uint64_t, ThreadId> res;
167+
for (const auto& kv : tids) {
168+
res[kv.first] = kv.second.GetTid();
49169
}
170+
return res;
171+
}
172+
173+
static constexpr const char* kDefaultThreadName = "unset";
174+
175+
std::string GetCurrentThreadName() {
176+
const auto& thread_name =
177+
internal::ThreadDataRegistry<std::string>::GetInstance()
178+
.GetCurrentThreadData();
179+
return thread_name.empty() ? kDefaultThreadName : thread_name;
180+
}
181+
182+
std::unordered_map<uint64_t, std::string> GetAllThreadNames() {
183+
return internal::ThreadDataRegistry<std::string>::GetInstance()
184+
.GetAllThreadDataByValue();
185+
}
186+
187+
bool SetCurrentThreadName(const std::string& name) {
188+
auto& instance = internal::ThreadDataRegistry<std::string>::GetInstance();
189+
const auto& cur_name = instance.GetCurrentThreadData();
190+
if (!cur_name.empty() || cur_name == kDefaultThreadName) {
191+
return false;
192+
}
193+
instance.SetCurrentThreadData(name);
194+
return true;
195+
}
196+
197+
uint32_t GetProcessId() {
198+
#if defined(_MSC_VER)
199+
return static_cast<uint32_t>(GetCurrentProcessId());
200+
#else
201+
return static_cast<uint32_t>(getpid());
202+
#endif
50203
}
51204

52205
} // namespace platform

paddle/fluid/platform/os_info.h

Lines changed: 28 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,12 @@ limitations under the License. */
1414

1515
#pragma once
1616

17-
#include <mutex>
18-
#include <thread>
17+
#include <string>
1918
#include <unordered_map>
20-
#include "paddle/fluid/platform/enforce.h" // import LIKELY
21-
#include "paddle/fluid/platform/macros.h" // import DISABLE_COPY_AND_ASSIGN
22-
#include "paddle/fluid/platform/port.h"
2319
#ifdef _POSIX_C_SOURCE
2420
#include <time.h>
2521
#endif
22+
#include "paddle/fluid/platform/port.h"
2623

2724
namespace paddle {
2825
namespace platform {
@@ -41,59 +38,38 @@ inline uint64_t PosixInNsec() {
4138
}
4239

4340
// All kinds of Ids for OS thread
44-
class ThreadId {
45-
public:
46-
ThreadId();
41+
struct ThreadId {
42+
uint64_t std_tid = 0; // std::hash<std::thread::id>
43+
uint64_t sys_tid = 0; // OS-specific, Linux: gettid
44+
uint32_t cupti_tid = 0; // thread_id used by Nvidia CUPTI
45+
};
4746

48-
uint64_t MainTid() const { return SysTid(); }
47+
// Better performance than GetCurrentThreadId
48+
uint64_t GetCurrentThreadStdId();
4949

50-
uint64_t StdTid() const { return std_tid_; }
50+
// Better performance than GetCurrentThreadId
51+
uint64_t GetCurrentThreadSysId();
5152

52-
uint32_t CuptiTid() const { return cupti_tid_; }
53+
ThreadId GetCurrentThreadId();
5354

54-
uint64_t SysTid() const { return sys_tid_ != 0 ? sys_tid_ : std_tid_; }
55+
// Return the map from StdTid to ThreadId
56+
// Returns current snapshot of all threads. Make sure there is no thread
57+
// create/destory when using it.
58+
std::unordered_map<uint64_t, ThreadId> GetAllThreadIds();
5559

56-
private:
57-
uint64_t std_tid_ = 0; // std::hash<std::thread::id>
58-
uint32_t cupti_tid_ = 0; // thread_id used by Nvidia CUPTI
59-
uint64_t sys_tid_ = 0; // OS-specific, Linux: gettid
60-
};
60+
// Returns 'unset' if SetCurrentThreadName is never called.
61+
std::string GetCurrentThreadName();
6162

62-
class ThreadIdRegistry {
63-
public:
64-
// singleton
65-
static ThreadIdRegistry& GetInstance() {
66-
static ThreadIdRegistry instance;
67-
return instance;
68-
}
69-
70-
const ThreadId* GetThreadId(uint64_t std_id) {
71-
std::lock_guard<std::mutex> lock(lock_);
72-
if (LIKELY(id_map_.find(std_id) != id_map_.end())) {
73-
return id_map_[std_id];
74-
}
75-
return nullptr;
76-
}
77-
78-
const ThreadId& CurrentThreadId() {
79-
static thread_local ThreadId* tid_ = nullptr;
80-
if (LIKELY(tid_ != nullptr)) {
81-
return *tid_;
82-
}
83-
tid_ = new ThreadId;
84-
std::lock_guard<std::mutex> lock(lock_);
85-
id_map_[tid_->StdTid()] = tid_;
86-
return *tid_;
87-
}
88-
89-
private:
90-
ThreadIdRegistry() = default;
91-
DISABLE_COPY_AND_ASSIGN(ThreadIdRegistry);
92-
~ThreadIdRegistry();
93-
94-
std::mutex lock_;
95-
std::unordered_map<uint64_t, ThreadId*> id_map_;
96-
};
63+
// Return the map from StdTid to ThreadName
64+
// Returns current snapshot of all threads. Make sure there is no thread
65+
// create/destory when using it.
66+
std::unordered_map<uint64_t, std::string> GetAllThreadNames();
67+
68+
// Thread name is immutable, only the first call will succeed.
69+
// Returns false on failure.
70+
bool SetCurrentThreadName(const std::string& name);
71+
72+
uint32_t GetProcessId();
9773

9874
} // namespace platform
9975
} // namespace paddle

paddle/fluid/platform/os_info_test.cc

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#include "paddle/fluid/platform/os_info.h"
15+
#include <thread>
16+
#include "gtest/gtest.h"
17+
18+
TEST(ThreadInfo, TestThreadIdUtils) {
19+
using paddle::platform::GetCurrentThreadStdId;
20+
using paddle::platform::GetCurrentThreadId;
21+
using paddle::platform::GetAllThreadIds;
22+
EXPECT_EQ(std::hash<std::thread::id>()(std::this_thread::get_id()),
23+
GetCurrentThreadId().std_tid);
24+
auto ids = GetAllThreadIds();
25+
EXPECT_TRUE(ids.find(GetCurrentThreadStdId()) != ids.end());
26+
}
27+
28+
TEST(ThreadInfo, TestThreadNameUtils) {
29+
using paddle::platform::GetCurrentThreadStdId;
30+
using paddle::platform::GetCurrentThreadName;
31+
using paddle::platform::SetCurrentThreadName;
32+
using paddle::platform::GetAllThreadNames;
33+
EXPECT_EQ("unset", GetCurrentThreadName());
34+
EXPECT_TRUE(SetCurrentThreadName("MainThread"));
35+
EXPECT_FALSE(SetCurrentThreadName("MainThread"));
36+
auto names = GetAllThreadNames();
37+
EXPECT_TRUE(names.find(GetCurrentThreadStdId()) != names.end());
38+
EXPECT_EQ("MainThread", names[GetCurrentThreadStdId()]);
39+
EXPECT_EQ("MainThread", GetCurrentThreadName());
40+
}

paddle/fluid/platform/profiler/host_event_recorder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ namespace paddle {
1616
namespace platform {
1717

1818
ThreadEventRecorder::ThreadEventRecorder() {
19-
thread_id_ = ThreadIdRegistry::GetInstance().CurrentThreadId().MainTid();
19+
thread_id_ = GetCurrentThreadSysId();
2020
HostEventRecorder::GetInstance().RegisterThreadRecorder(thread_id_, this);
2121
}
2222

0 commit comments

Comments
 (0)