
Commit 0e1191f

[Phi] Add phi device context pool (#40635)
* add phi device context pool
* change year
* fix compile error
* fix operator = error
* refine init impl
* polish details
* refine init impl
1 parent 276017b commit 0e1191f

File tree

8 files changed, +194 -26 lines changed


paddle/fluid/platform/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ endif()
 cc_library(cudnn_workspace_helper SRCS cudnn_workspace_helper.cc DEPS boost)
 
 # seperate init from device_context to avoid cycle dependencies
-cc_library(init SRCS init.cc DEPS device_context custom_kernel)
+cc_library(init SRCS init.cc DEPS device_context custom_kernel context_pool)
 
 # memcpy depends on device_context, here add deps individually for
 # avoiding cycle dependencies

paddle/fluid/platform/device_context.h

Lines changed: 5 additions & 0 deletions
@@ -916,6 +916,11 @@ class DeviceContextPool {
 
   size_t size() const { return device_contexts_.size(); }
 
+  const std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>&
+  device_contexts() const {
+    return device_contexts_;
+  }
+
  private:
   static DeviceContextPool* pool;
   std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>
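
The new device_contexts() accessor exposes the fluid pool's internal map read-only; each value is a std::shared_future that owns the DeviceContext through a unique_ptr, so a consumer has to unwrap it twice to reach the raw pointer without taking ownership. A minimal traversal sketch (names assumed, mirroring what the new phi-side pool further down this commit does):

    // Read-only walk over the fluid pool's contexts; ownership stays with
    // the fluid DeviceContextPool (illustrative sketch, not part of the commit).
    auto& fluid_pool = paddle::platform::DeviceContextPool::Instance();
    for (const auto& pair : fluid_pool.device_contexts()) {
      // pair.second is std::shared_future<std::unique_ptr<DeviceContext>>;
      // .get() waits on the future, the second .get() borrows the raw pointer.
      const paddle::platform::DeviceContext* dev_ctx = pair.second.get().get();
      (void)dev_ctx;  // e.g. record it in a non-owning map keyed by pair.first
    }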

paddle/phi/api/include/context_pool.h

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/phi/common/place.h"
+#include "paddle/phi/core/macros.h"
+#include "paddle/utils/flat_hash_map.h"
+
+namespace phi {
+class DeviceContext;
+class CPUContext;
+class GPUContext;
+} // namespace phi
+
+namespace paddle {
+namespace experimental {
+
+template <AllocationType T>
+struct DefaultDeviceContextType;
+
+template <>
+struct DefaultDeviceContextType<AllocationType::CPU> {
+  using TYPE = phi::CPUContext;
+};
+
+template <>
+struct DefaultDeviceContextType<AllocationType::GPU> {
+  using TYPE = phi::GPUContext;
+};
+
+/**
+ * The DeviceContextPool here is just a mirror of the DeviceContextPool in
+ * fluid, and does not manage the life cycle of the DeviceContext.
+ * It is mainly used for external custom operator calls and high-performance
+ * C++ APIs.
+ *
+ * Since DeviceContextPool in fluid is a global singleton, it always exists
+ * in program running, so DeviceContextPool here can always access the correct
+ * DeviceContext pointer.
+ *
+ * In order not to depend on the fluid's DeviceContextPool,
+ * the DeviceContextPool here needs to be initialized in the fluid, and cannot
+ * be initialized by itself.
+ */
+class DeviceContextPool {
+ public:
+  static DeviceContextPool& Instance();
+
+  const phi::DeviceContext* Get(const Place& place) const;
+
+  phi::DeviceContext* GetMutable(const Place& place);
+
+  template <AllocationType T>
+  const typename DefaultDeviceContextType<T>::TYPE* Get(
+      const Place& place) const {
+    return reinterpret_cast<const typename DefaultDeviceContextType<T>::TYPE*>(
+        Get(place));
+  }
+
+ private:
+  DeviceContextPool();
+  paddle::flat_hash_map<Place, const phi::DeviceContext*, Place::Hash>
+      context_map_;
+
+  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
+};
+
+} // namespace experimental
+} // namespace paddle
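
The header above gives external code two access paths: an untyped Get() returning a const phi::DeviceContext*, and a templated Get<AllocationType>() that uses DefaultDeviceContextType to resolve the concrete context class. A minimal usage sketch, assuming the pool has already been populated by fluid's init and that phi::CPUPlace is available from place.h:

    #include "paddle/phi/api/include/context_pool.h"
    #include "paddle/phi/common/place.h"

    void example() {
      // Hedged sketch; not part of the commit.
      auto& pool = paddle::experimental::DeviceContextPool::Instance();
      phi::CPUPlace cpu_place;

      // Untyped access: returns the base-class pointer.
      const phi::DeviceContext* generic_ctx = pool.Get(cpu_place);

      // Typed access: DefaultDeviceContextType<AllocationType::CPU>::TYPE
      // resolves to phi::CPUContext, so no manual cast is needed.
      const phi::CPUContext* cpu_ctx =
          pool.Get<paddle::experimental::AllocationType::CPU>(cpu_place);

      (void)generic_ctx;
      (void)cpu_ctx;
    }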

paddle/phi/api/lib/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -135,8 +135,9 @@ add_custom_command(
 
 cc_library(op_meta_info SRCS op_meta_info.cc DEPS phi_tensor_raw)
 cc_library(wrapped_infermeta SRCS ${wrapped_infermeta_source_file} DEPS phi)
+cc_library(context_pool SRCS context_pool.cc DEPS phi_context phi_enforce place)
 
-cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS phi_tensor_raw phi_context kernel_factory)
+cc_library(kernel_dispatch SRCS kernel_dispatch.cc DEPS phi_tensor_raw phi_context kernel_factory context_pool)
 cc_library(api_gen_utils SRCS api_gen_utils.cc DEPS phi_tensor_raw selected_rows sparse_csr_tensor sparse_coo_tensor)
 cc_library(phi_data_transform SRCS data_transform.cc DEPS phi_tensor_raw transfer_layout_kernel cast_kernel data_device_transform)
 cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform)

paddle/phi/api/lib/context_pool.cc

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/api/include/context_pool.h"
+
+#include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/core/enforce.h"
+
+namespace paddle {
+namespace experimental {
+
+DeviceContextPool& DeviceContextPool::Instance() {
+  static DeviceContextPool g_device_context_pool;
+  return g_device_context_pool;
+}
+
+const phi::DeviceContext* DeviceContextPool::Get(const Place& place) const {
+  auto it = context_map_.find(place);
+  PADDLE_ENFORCE_NE(
+      it,
+      context_map_.end(),
+      phi::errors::NotFound("The DeviceContext of %s does not exists.", place));
+  return it->second;
+}
+
+phi::DeviceContext* DeviceContextPool::GetMutable(const Place& place) {
+  return const_cast<phi::DeviceContext*>(Get(place));
+}
+
+DeviceContextPool::DeviceContextPool() {
+  // We need to make sure that the correct value exists
+  // whenever we get the DeviceContext from DeviceContextPool
+  const auto& device_contexts =
+      paddle::platform::DeviceContextPool::Instance().device_contexts();
+  for (const auto& pair : device_contexts) {
+    // only get CPU and GPU DeviceContext now, add other DeviceContext type
+    // later if needed
+    if (platform::is_cpu_place(pair.first)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+        ||
+        platform::is_gpu_place(pair.first)) {
+#else
+    ) {
+#endif
+      const phi::DeviceContext* dev_ctx = pair.second.get().get();
+      VLOG(3) << "Init phi DeviceContextPool: insert {" << pair.first << ", "
+              << dev_ctx << "}";
+      context_map_[pair.first] = dev_ctx;
+    }
+  }
+}
+
+} // namespace experimental
+} // namespace paddle
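
The if-condition in the constructor is split across preprocessor branches, which is hard to read in diff form: with CUDA or HIP enabled the filter accepts CPU and GPU places, otherwise only CPU places. An equivalent restructuring of just that predicate, for readability (a sketch with assumed names, not the committed code):

    // Equivalent filter, written out per build configuration.
    bool ShouldMirror(const phi::Place& place) {
    #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      return paddle::platform::is_cpu_place(place) ||
             paddle::platform::is_gpu_place(place);
    #else
      return paddle::platform::is_cpu_place(place);
    #endif
    }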

paddle/phi/api/lib/kernel_dispatch.cc

Lines changed: 3 additions & 2 deletions
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include "paddle/phi/api/lib/kernel_dispatch.h"
 
+#include "paddle/phi/api/include/context_pool.h"
 #include "paddle/phi/core/compat/convert_utils.h"
 
 namespace paddle {
@@ -52,8 +53,8 @@ std::size_t CountLeadingZeros(uint64_t val) {
 } // namespace detail
 
 phi::DeviceContext* GetDeviceContextByBackend(phi::Backend backend) {
-  auto& pool = paddle::platform::DeviceContextPool::Instance();
-  return pool.Get(phi::TransToPhiPlace(backend));
+  auto& pool = paddle::experimental::DeviceContextPool::Instance();
+  return pool.GetMutable(phi::TransToPhiPlace(backend));
 }
 
 DataType ParseDataType(DataType dtype) { return dtype; }
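
With this change, GetDeviceContextByBackend no longer reaches into the fluid pool directly: the backend is translated to a phi::Place by phi::TransToPhiPlace and resolved through the phi-side mirror, and the result is now a mutable pointer. A hedged calling sketch, assuming the helper lives in paddle::experimental like the rest of kernel_dispatch:

    // Resolve the context a CPU kernel will run on (illustrative only).
    phi::DeviceContext* dev_ctx =
        paddle::experimental::GetDeviceContextByBackend(phi::Backend::CPU);
    // Roughly equivalent to:
    //   paddle::experimental::DeviceContextPool::Instance().GetMutable(
    //       phi::TransToPhiPlace(phi::Backend::CPU));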

paddle/phi/common/place.cc

Lines changed: 16 additions & 0 deletions
@@ -92,4 +92,20 @@ std::string GetGlobalDeviceType(size_t device_type_id) {
   return global_registered_device_type[device_type_id];
 }
 
+constexpr static int kAllocationTypeBitLength = 8;
+constexpr static int kDeviceTypeIDBitLength = 8;
+constexpr static int kDeviceIDBitLength = 8;
+
+uint32_t Place::Hash::operator()(const Place &place) const {
+  uint32_t hash_value = 0;
+  // |----31-24------|-----23-16------|-----15-08----|---7-0----|
+  // | For extension | AllocationType | DeviceTypeID | DeviceID |
+  hash_value |= (static_cast<uint8_t>(place.alloc_type_)
+                 << (kDeviceIDBitLength + kDeviceTypeIDBitLength));
+  hash_value |=
+      (static_cast<uint8_t>(place.device_type_id_) << kDeviceIDBitLength);
+  hash_value |= static_cast<uint8_t>(place.device);
+  return hash_value;
+}
+
 } // namespace phi
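
The hash packs three 8-bit fields into the low 24 bits of a uint32_t, leaving the top byte free for extension. A self-contained sketch of the same packing with illustrative field values (the real numbers come from phi::AllocationType and the Place members, which are assumed here):

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Illustrative values only: suppose the allocation type encodes to 2,
      // the device type id is 0 and the device id is 3.
      const uint8_t alloc_type = 2;
      const uint8_t device_type_id = 0;
      const uint8_t device_id = 3;

      uint32_t hash_value = 0;
      hash_value |= static_cast<uint32_t>(alloc_type) << 16;     // bits 23-16
      hash_value |= static_cast<uint32_t>(device_type_id) << 8;  // bits 15-8
      hash_value |= device_id;                                   // bits 7-0

      std::printf("0x%06x\n", static_cast<unsigned>(hash_value));  // 0x020003
      return 0;
    }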

paddle/phi/common/place.h

Lines changed: 21 additions & 22 deletions
@@ -73,31 +73,23 @@ class Place {
 
   std::string DebugString() const;
 
+  struct Hash {
+    // Note: Now the number of bits we need does not exceed 32 bits, so there is
+    // no need to use 64 bits. If needed in the future, it can be expanded,
+    // but now we don’t over-design.
+    uint32_t operator()(const Place& place) const;
+  };
+
+  uint32_t HashValue() const { return Hash()(*this); }
+
   inline bool operator==(const Place& rhs) const {
-    if (alloc_type_ != rhs.GetType()) {
-      return false;
-    }
-    if (alloc_type_ == AllocationType::CPU ||
-        alloc_type_ == AllocationType::GPUPINNED ||
-        alloc_type_ == AllocationType::NPUPINNED) {
-      return true;
-    }
-    if (alloc_type_ == AllocationType::CUSTOM) {
-      return device_type_id_ == rhs.device_type_id_ &&
-             device == rhs.GetDeviceId();
-    }
-    return device == rhs.GetDeviceId();
+    return HashValue() == rhs.HashValue();
+  }
+  inline bool operator!=(const Place& rhs) const {
+    return HashValue() != rhs.HashValue();
   }
-  inline bool operator!=(const Place& rhs) const { return !(*this == rhs); }
   inline bool operator<(const Place& rhs) const {
-    if (alloc_type_ != rhs.GetType()) {
-      return static_cast<int>(alloc_type_) < static_cast<int>(rhs.GetType());
-    }
-    if (alloc_type_ == AllocationType::CUSTOM &&
-        device_type_id_ != rhs.device_type_id_) {
-      return device_type_id_ < rhs.device_type_id_;
-    }
-    return device < rhs.GetDeviceId();
+    return HashValue() < rhs.HashValue();
   }
 
  public:
@@ -206,3 +198,10 @@ class CustomPlace : public Place {
 std::ostream& operator<<(std::ostream&, const Place&);
 
 } // namespace phi
+
+namespace paddle {
+namespace experimental {
+using AllocationType = phi::AllocationType;
+using Place = phi::Place;
+} // namespace experimental
+} // namespace paddle
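
Because operator==, operator!= and operator< now compare the packed 32-bit hashes, two places are equal exactly when their allocation type, device type id and device id all match, and ordering follows the byte layout (allocation type first, then device type id, then device id). A small hedged sketch, assuming phi::GPUPlace offers an integer device-id constructor:

    phi::GPUPlace gpu0_a(0), gpu0_b(0), gpu1(1);
    // Same packed hash, so the places compare equal.
    bool same = (gpu0_a == gpu0_b);   // true
    // The device id occupies the lowest byte, so device 0 sorts before 1.
    bool ordered = (gpu0_a < gpu1);   // true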
