Skip to content

Commit d91d758

Browse files
authored
[IR] OpTrait & OpInterface & OpInfo (PaddlePaddle#52846)
* add OpTrait OpInterface ValueIterator TypeList * refine code * refine code * refine code * add opinfo * add typeid copy constructor * add trait interface construct method for opinfo_impl * add trait interface construct method for opinfo_impl * add trait interface construct method for opinfo_impl * add trait interface construct method for opinfo_impl * add trait interface construct method for opinfo_impl * add create * add member func for opinfo * fix compile bug * add op interface in ircontext * fix compile bug * fix compile bug * refine code * fix compile bug * add ut * refine ut * refine code of opinfo_impl * delete unused code * add dyncast for operation * refine comment * refine opinfo_impl * delete unused code * refine code by comment * refine code * refine code * refine code for registerOp * refine opfin create * refine code of search method of ircontext * refine op attribute * change opinfo_map key from type_id to string
1 parent b729512 commit d91d758

17 files changed

+898
-73
lines changed

paddle/ir/builtin_attribute_storage.cc

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,16 @@ DictionaryAttributeStorage::ParamKey DictionaryAttributeStorage::GetAsKey()
5959
}
6060

6161
Attribute DictionaryAttributeStorage::GetValue(const StrAttribute &name) const {
62-
if (size_ > 0) {
63-
size_t left = 0;
64-
size_t right = size_ - 1;
65-
size_t mid = 0;
66-
while (left <= right) {
67-
mid = (left + right) / 2;
68-
if (data_[mid].name() == name) {
69-
return data_[mid].value();
70-
} else if (data_[mid].name() < name) {
71-
left = mid + 1;
72-
} else {
73-
right = mid - 1;
74-
}
62+
size_t left = 0;
63+
size_t right = size_;
64+
while (left < right) {
65+
size_t mid = left + (right - left) / 2;
66+
if (data_[mid].name() == name) {
67+
return data_[mid].value();
68+
} else if (data_[mid].name() < name) {
69+
left = mid + 1;
70+
} else {
71+
right = mid;
7572
}
7673
}
7774
return nullptr;

paddle/ir/dialect.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,8 @@ void Dialect::RegisterAttribute(ir::AbstractAttribute &&abstract_attribute) {
3131
this->ir_context()->RegisterAbstractAttribute(
3232
new_abstract_attribute->type_id(), new_abstract_attribute);
3333
}
34+
35+
void Dialect::RegisterOp(const std::string &name, OpInfoImpl *op_info) {
36+
this->ir_context()->RegisterOpInfo(name, op_info);
37+
}
3438
} // namespace ir

paddle/ir/dialect.h

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "paddle/ir/attribute_base.h"
1818
#include "paddle/ir/ir_context.h"
19+
#include "paddle/ir/op_info_impl.h"
1920
#include "paddle/ir/type_base.h"
2021

2122
namespace ir {
@@ -45,17 +46,19 @@ class Dialect {
4546
(void)std::initializer_list<int>{0, (RegisterType<Args>(), 0)...};
4647
}
4748

48-
///
49-
/// \brief Register type of class T.
50-
///
5149
template <typename T>
5250
void RegisterType() {
5351
VLOG(4) << "Type registered into Dialect. --->";
54-
ir::AbstractType *abstract_type =
55-
new ir::AbstractType(std::move(ir::AbstractType::get<T>(*this)));
56-
this->ir_context()->RegisterAbstractType(ir::TypeId::get<T>(),
57-
abstract_type);
58-
ir::TypeManager::RegisterType<T>(this->ir_context());
52+
// if (this->ir_context()->registed_abstract_type().count(
53+
// ir::TypeId::get<T>()) == 0) {
54+
if (this->ir_context()->GetRegisteredAbstractType(ir::TypeId::get<T>()) ==
55+
nullptr) {
56+
ir::AbstractType *abstract_type =
57+
new ir::AbstractType(std::move(ir::AbstractType::get<T>(*this)));
58+
this->ir_context()->RegisterAbstractType(ir::TypeId::get<T>(),
59+
abstract_type);
60+
ir::TypeManager::RegisterType<T>(this->ir_context());
61+
}
5962
VLOG(4) << "----------------------------------";
6063
}
6164

@@ -78,24 +81,42 @@ class Dialect {
7881
(void)std::initializer_list<int>{0, (RegisterAttribute<Args>(), 0)...};
7982
}
8083

81-
///
82-
/// \brief Register attribute of class T.
83-
///
8484
template <typename T>
8585
void RegisterAttribute() {
8686
VLOG(4) << "Attribute registered into Dialect. --->";
87-
ir::AbstractAttribute *abstract_attribute = new ir::AbstractAttribute(
88-
std::move(ir::AbstractAttribute::get<T>(*this)));
89-
this->ir_context()->RegisterAbstractAttribute(ir::TypeId::get<T>(),
90-
abstract_attribute);
91-
ir::AttributeManager::RegisterAttribute<T>(this->ir_context());
87+
if (this->ir_context()->GetRegisteredAbstractAttribute(
88+
ir::TypeId::get<T>()) == nullptr) {
89+
ir::AbstractAttribute *abstract_attribute = new ir::AbstractAttribute(
90+
std::move(ir::AbstractAttribute::get<T>(*this)));
91+
this->ir_context()->RegisterAbstractAttribute(ir::TypeId::get<T>(),
92+
abstract_attribute);
93+
ir::AttributeManager::RegisterAttribute<T>(this->ir_context());
94+
}
9295
VLOG(4) << "----------------------------------";
9396
}
9497

98+
void RegisterAttribute(ir::AbstractAttribute &&abstract_attribute);
99+
95100
///
96-
/// \brief Register abstract_attribute into context.
101+
/// \brief Register Operation methods.
97102
///
98-
void RegisterAttribute(ir::AbstractAttribute &&abstract_attribute);
103+
template <typename... Args>
104+
void RegisterOps() {
105+
(void)std::initializer_list<int>{0, (RegisterOp<Args>(), 0)...};
106+
}
107+
108+
template <typename ConcertOp>
109+
void RegisterOp() {
110+
std::string name = this->name() + "." + std::string(ConcertOp::name());
111+
VLOG(4) << "Op " << name << " registered into Dialect. --->";
112+
if (this->ir_context()->GetRegisteredOpInfo(name) == nullptr) {
113+
ir::OpInfoImpl *op_info = ir::OpInfoImpl::create<ConcertOp>(this);
114+
this->ir_context()->RegisterOpInfo(name, op_info);
115+
}
116+
VLOG(4) << "----------------------------------";
117+
}
118+
119+
void RegisterOp(const std::string &name, OpInfoImpl *op_info);
99120

100121
private:
101122
std::string name_;

paddle/ir/ir_context.cc

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "paddle/ir/builtin_dialect.h"
2121
#include "paddle/ir/builtin_type.h"
2222
#include "paddle/ir/dialect.h"
23+
#include "paddle/ir/op_info_impl.h"
2324
#include "paddle/ir/spin_lock.h"
2425
#include "paddle/ir/type_base.h"
2526

@@ -46,6 +47,11 @@ class IrContextImpl {
4647
delete dialect_map.second;
4748
}
4849
registed_dialect_.clear();
50+
51+
for (auto &op_map : registed_op_infos_) {
52+
op_map.second->destroy();
53+
}
54+
registed_op_infos_.clear();
4955
}
5056

5157
void RegisterAbstractType(ir::TypeId type_id, AbstractType *abstract_type) {
@@ -93,6 +99,25 @@ class IrContextImpl {
9399
return nullptr;
94100
}
95101

102+
void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) {
103+
std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
104+
VLOG(4) << "Register an operation of: [Name=" << name
105+
<< ", OpInfoImpl ptr=" << opinfo << "].";
106+
registed_op_infos_.emplace(name, opinfo);
107+
}
108+
109+
OpInfoImpl *GetOpInfo(const std::string &name) {
110+
std::lock_guard<ir::SpinLock> guard(registed_op_infos_lock_);
111+
auto iter = registed_op_infos_.find(name);
112+
if (iter != registed_op_infos_.end()) {
113+
VLOG(4) << "Fonund a cached operation of: [name=" << name
114+
<< ", OpInfoImpl ptr=" << iter->second << "].";
115+
return iter->second;
116+
}
117+
LOG(WARNING) << "No cache found operation of: [Name=" << name << "].";
118+
return nullptr;
119+
}
120+
96121
void RegisterDialect(std::string name, Dialect *dialect) {
97122
std::lock_guard<ir::SpinLock> guard(registed_dialect_lock_);
98123
VLOG(4) << "Register a dialect of: [name=" << name
@@ -135,6 +160,10 @@ class IrContextImpl {
135160
std::unordered_map<std::string, Dialect *> registed_dialect_;
136161
ir::SpinLock registed_dialect_lock_;
137162

163+
// The Op registered in the context.
164+
std::unordered_map<std::string, OpInfoImpl *> registed_op_infos_;
165+
ir::SpinLock registed_op_infos_lock_;
166+
138167
ir::SpinLock destructor_lock_;
139168
};
140169

@@ -165,9 +194,12 @@ StorageManager &IrContext::type_storage_manager() {
165194
return impl().registed_type_storage_manager_;
166195
}
167196

168-
std::unordered_map<TypeId, AbstractType *>
169-
&IrContext::registed_abstracted_type() {
170-
return impl().registed_abstract_types_;
197+
AbstractType *IrContext::GetRegisteredAbstractType(TypeId id) {
198+
auto search = impl().registed_abstract_types_.find(id);
199+
if (search != impl().registed_abstract_types_.end()) {
200+
return search->second;
201+
}
202+
return nullptr;
171203
}
172204

173205
void IrContext::RegisterAbstractAttribute(
@@ -179,9 +211,12 @@ StorageManager &IrContext::attribute_storage_manager() {
179211
return impl().registed_attribute_storage_manager_;
180212
}
181213

182-
std::unordered_map<TypeId, AbstractAttribute *>
183-
&IrContext::registed_abstracted_attribute() {
184-
return impl().registed_abstract_attributes_;
214+
AbstractAttribute *IrContext::GetRegisteredAbstractAttribute(TypeId id) {
215+
auto search = impl().registed_abstract_attributes_.find(id);
216+
if (search != impl().registed_abstract_attributes_.end()) {
217+
return search->second;
218+
}
219+
return nullptr;
185220
}
186221

187222
Dialect *IrContext::GetOrRegisterDialect(
@@ -216,6 +251,17 @@ Dialect *IrContext::GetRegisteredDialect(const std::string &dialect_name) {
216251
return nullptr;
217252
}
218253

254+
OpInfoImpl *IrContext::GetRegisteredOpInfo(const std::string &name) {
255+
OpInfoImpl *rtn = impl().GetOpInfo(name);
256+
return rtn ? rtn : nullptr;
257+
}
258+
259+
void IrContext::RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo) {
260+
if (impl().GetOpInfo(name) == nullptr) {
261+
impl().RegisterOpInfo(name, opinfo);
262+
}
263+
}
264+
219265
const AbstractType &AbstractType::lookup(TypeId type_id, IrContext *ctx) {
220266
auto &impl = ctx->impl();
221267
AbstractType *abstract_type = impl.GetAbstractType(type_id);

paddle/ir/ir_context.h

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class AbstractType;
2626
class AbstractAttribute;
2727
class TypeId;
2828
class Dialect;
29+
class OpInfoImpl;
2930

3031
///
3132
/// \brief IrContext is a global parameterless class used to store and manage
@@ -47,7 +48,7 @@ class IrContext {
4748
IrContextImpl &impl() { return *impl_; }
4849

4950
///
50-
/// \brief Register an AbstractType to IrContext
51+
/// \brief Register an AbstractType to IrContext.
5152
///
5253
/// \param type_id The type id of the AbstractType.
5354
/// \param abstract_type AbstractType* provided by user.
@@ -64,13 +65,9 @@ class IrContext {
6465
StorageManager &type_storage_manager();
6566

6667
///
67-
/// \brief Returns the storage uniquer used for constructing TypeStorage
68-
/// instances.
69-
///
70-
/// \return The storage uniquer used for constructing TypeStorage
71-
/// instances.
68+
/// \brief Get registered AbstractType from IrContext.
7269
///
73-
std::unordered_map<TypeId, AbstractType *> &registed_abstracted_type();
70+
AbstractType *GetRegisteredAbstractType(TypeId id);
7471

7572
///
7673
/// \brief Register an AbstractAttribute to IrContext
@@ -91,14 +88,16 @@ class IrContext {
9188
StorageManager &attribute_storage_manager();
9289

9390
///
94-
/// \brief Returns the storage uniquer used for constructing AttributeStorage
95-
/// instances.
91+
/// \brief Get registered AbstractAttribute from IrContext.
9692
///
97-
/// \return The storage uniquer used for constructing AttributeStorage
98-
/// instances.
93+
AbstractAttribute *GetRegisteredAbstractAttribute(TypeId id);
94+
95+
///
96+
/// \brief Get or register operaiton.
9997
///
100-
std::unordered_map<TypeId, AbstractAttribute *>
101-
&registed_abstracted_attribute();
98+
void RegisterOpInfo(const std::string &name, OpInfoImpl *opinfo);
99+
100+
OpInfoImpl *GetRegisteredOpInfo(const std::string &name);
102101

103102
///
104103
/// \brief Get the dialect of the DialectT class in the context, ff not found,

paddle/ir/op_base.h

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,59 @@
1515
#pragma once
1616

1717
#include "paddle/ir/operation.h"
18+
#include "paddle/ir/utils.h"
1819

1920
namespace ir {
2021
class OpBase {
2122
public:
22-
Operation *operation() { return operation_; }
23+
explicit OpBase(const Operation *operation) : operation_(operation) {}
2324

24-
explicit operator bool() { return operation() != nullptr; }
25+
const Operation *operation() const { return operation_; }
2526

26-
operator Operation *() const { return operation_; }
27+
explicit operator bool() const { return operation() != nullptr; }
2728

28-
Operation *operator->() const { return operation_; }
29+
operator const Operation *() const { return operation_; }
2930

30-
protected:
31-
explicit OpBase(Operation *operation) : operation_(operation) {}
31+
const Operation *operator->() const { return operation_; }
3232

3333
private:
34-
Operation *operation_;
34+
const Operation *operation_; // Not owned
35+
};
36+
37+
///
38+
/// \brief OpTrait
39+
///
40+
template <class ConcreteTrait>
41+
class OpTraitBase : public OpBase {
42+
public:
43+
explicit OpTraitBase(const Operation *op) : OpBase(op) {}
44+
45+
static TypeId GetTraitId() { return TypeId::get<ConcreteTrait>(); }
46+
};
47+
48+
///
49+
/// \brief OpInterface
50+
///
51+
template <typename ConcreteInterface>
52+
class OpInterfaceBase : public OpBase {
53+
public:
54+
// explicit OpInterfaceBase(Operation *op) : OpBase(op) {}
55+
56+
explicit OpInterfaceBase(const Operation *op) : OpBase(op) {}
57+
58+
static TypeId GetInterfaceId() { return TypeId::get<ConcreteInterface>(); }
59+
};
60+
61+
template <typename ConcreteOp, class... TraitOrInterface>
62+
class Op : public OpBase {
63+
public:
64+
using OpBase::OpBase;
65+
66+
using TraitList =
67+
typename Filter<OpTraitBase, std::tuple<TraitOrInterface...>>::Type;
68+
69+
using InterfaceList =
70+
typename Filter<OpInterfaceBase, std::tuple<TraitOrInterface...>>::Type;
3571
};
3672

3773
} // namespace ir

0 commit comments

Comments
 (0)