|
16 | 16 |
|
17 | 17 | #include <ostream> |
18 | 18 | #include <string> |
19 | | -#include <unordered_map> |
20 | 19 | #include <utility> |
21 | 20 |
|
22 | 21 | #include "paddle/tcmpt/core/backend.h" |
23 | 22 | #include "paddle/tcmpt/core/dtype.h" |
24 | 23 | #include "paddle/tcmpt/core/kernel_def.h" |
25 | 24 | #include "paddle/tcmpt/core/layout.h" |
| 25 | +#include "paddle/utils/flat_hash_map.h" |
| 26 | +#include "paddle/utils/small_vector.h" |
26 | 27 |
|
27 | 28 | // See Note [ Why still include the fluid headers? ] |
28 | 29 | #include "paddle/fluid/platform/enforce.h" |
@@ -209,25 +210,30 @@ class KernelArgsDef { |
209 | 210 | attribute_defs_.emplace_back(AttributeArgDef(type_index)); |
210 | 211 | } |
211 | 212 |
|
212 | | - const std::vector<TensorArgDef>& input_defs() const { return input_defs_; } |
| 213 | + const paddle::SmallVector<TensorArgDef>& input_defs() const { |
| 214 | + return input_defs_; |
| 215 | + } |
213 | 216 |
|
214 | | - const std::vector<TensorArgDef>& output_defs() const { return output_defs_; } |
| 217 | + const paddle::SmallVector<TensorArgDef>& output_defs() const { |
| 218 | + return output_defs_; |
| 219 | + } |
215 | 220 |
|
216 | | - const std::vector<AttributeArgDef>& attribute_defs() const { |
| 221 | + const paddle::SmallVector<AttributeArgDef>& attribute_defs() const { |
217 | 222 | return attribute_defs_; |
218 | 223 | } |
219 | 224 |
|
220 | | - std::vector<TensorArgDef>& input_defs() { return input_defs_; } |
| 225 | + paddle::SmallVector<TensorArgDef>& input_defs() { return input_defs_; } |
221 | 226 |
|
222 | | - std::vector<TensorArgDef>& output_defs() { return output_defs_; } |
| 227 | + paddle::SmallVector<TensorArgDef>& output_defs() { return output_defs_; } |
223 | 228 |
|
224 | | - std::vector<AttributeArgDef>& attribute_defs() { return attribute_defs_; } |
| 229 | + paddle::SmallVector<AttributeArgDef>& attribute_defs() { |
| 230 | + return attribute_defs_; |
| 231 | + } |
225 | 232 |
|
226 | 233 | private: |
227 | | - // TODO(chenweihang): replaced by paddle::small_vector |
228 | | - std::vector<TensorArgDef> input_defs_{{}}; |
229 | | - std::vector<TensorArgDef> output_defs_{{}}; |
230 | | - std::vector<AttributeArgDef> attribute_defs_{{}}; |
| 234 | + paddle::SmallVector<TensorArgDef> input_defs_{{}}; |
| 235 | + paddle::SmallVector<TensorArgDef> output_defs_{{}}; |
| 236 | + paddle::SmallVector<AttributeArgDef> attribute_defs_{{}}; |
231 | 237 | }; |
232 | 238 |
|
233 | 239 | class Kernel { |
@@ -263,10 +269,10 @@ class Kernel { |
263 | 269 | class KernelFactory { |
264 | 270 | public: |
265 | 271 | // replaced by paddle::flat_hash_map later |
266 | | - using KernelMap = |
267 | | - std::unordered_map<KernelName, |
268 | | - std::unordered_map<KernelKey, Kernel, KernelKey::Hash>, |
269 | | - KernelName::Hash>; |
| 272 | + using KernelMap = paddle::flat_hash_map< |
| 273 | + KernelName, |
| 274 | + paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>, |
| 275 | + KernelName::Hash>; |
270 | 276 |
|
271 | 277 | static KernelFactory& Instance(); |
272 | 278 |
|
|
0 commit comments