Skip to content

Commit 94a56cc

Browse files
authored
[IR]Add BFloat16 in IrContextImpl (PaddlePaddle#54281)
1 parent c71198f commit 94a56cc

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

paddle/ir/core/builtin_type.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ class BFloat16Type : public Type {
4242
using Type::Type;
4343

4444
DECLARE_TYPE_UTILITY_FUNCTOR(BFloat16Type, TypeStorage);
45+
46+
static BFloat16Type get(IrContext *context);
4547
};
4648

4749
class Int8Type : public Type {

paddle/ir/core/dialect.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ Dialect::~Dialect() = default;
2222

2323
void Dialect::RegisterInterface(std::unique_ptr<DialectInterface> interface) {
2424
VLOG(4) << "Register interface into dialect" << std::endl;
25-
auto it = registered_interfaces_.emplace(interface->interface_id(),
26-
std::move(interface));
27-
(void)it;
25+
registered_interfaces_.emplace(interface->interface_id(),
26+
std::move(interface));
2827
}
2928

3029
DialectInterface::~DialectInterface() = default;

paddle/ir/core/ir_context.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ class IrContextImpl {
151151
// TypeStorage uniquer and cache instances.
152152
StorageManager registed_type_storage_manager_;
153153
// Cache some built-in type objects.
154+
BFloat16Type bfp16_type;
154155
Float16Type fp16_type;
155156
Float32Type fp32_type;
156157
Float64Type fp64_type;
@@ -185,6 +186,7 @@ IrContext::IrContext() : impl_(new IrContextImpl()) {
185186
GetOrRegisterDialect<BuiltinDialect>();
186187
VLOG(4) << "==============================================";
187188

189+
impl_->bfp16_type = TypeManager::get<BFloat16Type>(this);
188190
impl_->fp16_type = TypeManager::get<Float16Type>(this);
189191
impl_->fp32_type = TypeManager::get<Float32Type>(this);
190192
impl_->fp64_type = TypeManager::get<Float64Type>(this);
@@ -319,6 +321,10 @@ const AbstractAttribute &AbstractAttribute::lookup(TypeId type_id,
319321
}
320322
}
321323

324+
BFloat16Type BFloat16Type::get(IrContext *ctx) {
325+
return ctx->impl().bfp16_type;
326+
}
327+
322328
Float16Type Float16Type::get(IrContext *ctx) { return ctx->impl().fp16_type; }
323329

324330
Float32Type Float32Type::get(IrContext *ctx) { return ctx->impl().fp32_type; }

0 commit comments

Comments
 (0)