Skip to content

Commit 2309149

Browse files
committed
use flat_hash_map and small_vector in kernel factory
1 parent 3f5f789 commit 2309149

File tree

1 file changed

+21
-15
lines changed

1 file changed

+21
-15
lines changed

paddle/tcmpt/core/kernel_factory.h

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616

1717
#include <ostream>
1818
#include <string>
19-
#include <unordered_map>
2019
#include <utility>
2120

2221
#include "paddle/tcmpt/core/backend.h"
2322
#include "paddle/tcmpt/core/dtype.h"
2423
#include "paddle/tcmpt/core/kernel_def.h"
2524
#include "paddle/tcmpt/core/layout.h"
25+
#include "paddle/utils/flat_hash_map.h"
26+
#include "paddle/utils/small_vector.h"
2627

2728
// See Note [ Why still include the fluid headers? ]
2829
#include "paddle/fluid/platform/enforce.h"
@@ -209,25 +210,30 @@ class KernelArgsDef {
209210
attribute_defs_.emplace_back(AttributeArgDef(type_index));
210211
}
211212

212-
const std::vector<TensorArgDef>& input_defs() const { return input_defs_; }
213+
const paddle::SmallVector<TensorArgDef>& input_defs() const {
214+
return input_defs_;
215+
}
213216

214-
const std::vector<TensorArgDef>& output_defs() const { return output_defs_; }
217+
const paddle::SmallVector<TensorArgDef>& output_defs() const {
218+
return output_defs_;
219+
}
215220

216-
const std::vector<AttributeArgDef>& attribute_defs() const {
221+
const paddle::SmallVector<AttributeArgDef>& attribute_defs() const {
217222
return attribute_defs_;
218223
}
219224

220-
std::vector<TensorArgDef>& input_defs() { return input_defs_; }
225+
paddle::SmallVector<TensorArgDef>& input_defs() { return input_defs_; }
221226

222-
std::vector<TensorArgDef>& output_defs() { return output_defs_; }
227+
paddle::SmallVector<TensorArgDef>& output_defs() { return output_defs_; }
223228

224-
std::vector<AttributeArgDef>& attribute_defs() { return attribute_defs_; }
229+
paddle::SmallVector<AttributeArgDef>& attribute_defs() {
230+
return attribute_defs_;
231+
}
225232

226233
private:
227-
// TODO(chenweihang): replaced by paddle::small_vector
228-
std::vector<TensorArgDef> input_defs_{{}};
229-
std::vector<TensorArgDef> output_defs_{{}};
230-
std::vector<AttributeArgDef> attribute_defs_{{}};
234+
paddle::SmallVector<TensorArgDef> input_defs_{{}};
235+
paddle::SmallVector<TensorArgDef> output_defs_{{}};
236+
paddle::SmallVector<AttributeArgDef> attribute_defs_{{}};
231237
};
232238

233239
class Kernel {
@@ -263,10 +269,10 @@ class Kernel {
263269
class KernelFactory {
264270
public:
265271
// replaced by paddle::flat_hash_map later
266-
using KernelMap =
267-
std::unordered_map<KernelName,
268-
std::unordered_map<KernelKey, Kernel, KernelKey::Hash>,
269-
KernelName::Hash>;
272+
using KernelMap = paddle::flat_hash_map<
273+
KernelName,
274+
paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>,
275+
KernelName::Hash>;
270276

271277
static KernelFactory& Instance();
272278

0 commit comments

Comments
 (0)