Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] Support Region Clone in Operation::Clone #60590

Merged
merged 10 commits into from
Jan 9, 2024
2 changes: 1 addition & 1 deletion paddle/pir/core/block.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class IR_API Block {
friend class Region;
void SetParent(Region *parent);

// Take out corresponding Operation and its ownershipe.
// Take out corresponding Operation and its ownership.
friend class Operation;
Operation *Take(Operation *op);

Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/ir_mapping.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#pragma once
#include <unordered_map>
#include "paddle/common/enforce.h"
#include "paddle/pir/core/block.h"
#include "paddle/pir/core/value.h"

namespace pir {

Expand Down
10 changes: 8 additions & 2 deletions paddle/pir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,6 @@ Operation *Operation::Create(const std::vector<Value> &inputs,
}

Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) {
huangjiyi marked this conversation as resolved.
Show resolved Hide resolved
IR_ENFORCE(!options.IsCloneRegions() || num_regions_ <= 0,
"Operation CloneRegions is unimplemented currently.");
IR_ENFORCE(num_successors_ == 0,
"Operation::Clone is not unimplemented for multiple successors.");

Expand All @@ -160,6 +158,14 @@ Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) {
for (uint32_t i = 0; i < num_results_; ++i) {
ir_mapping.Add(result(i), new_op->result(i));
}

if (options.IsCloneRegions()) {
// clone regions recursively
for (uint32_t i = 0; i < num_regions_; ++i) {
this->region(i).CloneInto(new_op->region(i), ir_mapping);
}
}

return new_op;
}

Expand Down
49 changes: 49 additions & 0 deletions paddle/pir/core/region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,55 @@ Region::Iterator Region::erase(ConstIterator position) {
return blocks_.erase(position);
}

void Region::CloneInto(Region &other, IrMapping &ir_mapping) const {
if (empty()) {
return;
}
other.clear();
// clone blocks and block arguments.
for (auto iter = blocks_.begin(); iter != blocks_.end(); ++iter) {
huangjiyi marked this conversation as resolved.
Show resolved Hide resolved
auto new_block = new Block;
for (const auto &arg : (*iter)->args()) {
ir_mapping.Add(arg, new_block->AddArgument(arg.type()));
}
other.push_back(new_block);
}
// clone operations of each block, but not set mapped operands nor clone
// regions
auto clone_options = CloneOptions(false, false);
{
auto iter = blocks_.begin();
auto new_iter = other.begin();
for (; iter != blocks_.end(); ++iter, ++new_iter) {
const Block &block = **iter;
Block &new_block = *new_iter;
for (auto op_iter = block.begin(); op_iter != block.end(); ++op_iter) {
new_block.push_back((*op_iter).Clone(ir_mapping, clone_options));
}
}
}
// after all operation results are mapped, clone operands and regions.
{
auto iter = blocks_.begin();
huangjiyi marked this conversation as resolved.
Show resolved Hide resolved
auto new_iter = other.begin();
for (; iter != blocks_.end(); ++iter, ++new_iter) {
auto op_iter = (*iter)->begin();
auto new_op_iter = (*new_iter).begin();
for (; op_iter != (*iter)->end(); ++op_iter, ++new_op_iter) {
const Operation &op = *op_iter;
Operation &new_op = *new_op_iter;
// operands of new operation are same as source, now map them.
for (uint32_t i = 0; i < new_op.num_operands(); ++i)
new_op.operand(i).set_source(
ir_mapping.Lookup(new_op.operand_source(i)));
// clone sub regions
for (uint32_t i = 0; i < op.num_regions(); ++i)
op.region(i).CloneInto(new_op.region(i), ir_mapping);
}
}
}
}

std::unique_ptr<pir::Block> Region::TakeBack() {
Block *block = nullptr;
if (!blocks_.empty()) {
Expand Down
4 changes: 4 additions & 0 deletions paddle/pir/core/region.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <memory>

#include "paddle/pir/core/dll_decl.h"
#include "paddle/pir/core/ir_mapping.h"
#include "paddle/pir/core/iterator.h"
#include "paddle/pir/core/visitors.h"

Expand Down Expand Up @@ -71,6 +72,9 @@ class IR_API Region {
template <WalkOrder Order = WalkOrder::PostOrder, typename FuncT>
void Walk(FuncT &&callback);

// clone this region into another region, target region will be overwritten.
void CloneInto(Region &other, IrMapping &ir_mapping) const; // NOLINT

// take the last block of region.
// if region is empty, return nullptr;
std::unique_ptr<Block> TakeBack();
Expand Down
60 changes: 57 additions & 3 deletions test/cpp/pir/core/ir_region_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/core/ir_context.h"
#include "paddle/pir/core/ir_mapping.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/core/utils.h"

TEST(region, erase_op_test) {
// (1) Init environment.
pir::IrContext* ctx = pir::IrContext::Instance();
pir::IrContext *ctx = pir::IrContext::Instance();

// (2) Create an empty program object
pir::Program program(ctx);
Expand All @@ -43,15 +44,68 @@ TEST(region, erase_op_test) {
builder.Build<pir::CombineOp>(std::vector<pir::Value>{a, b});

// Test pir::Block::erase
pir::Block* block = program.block();
pir::Block *block = program.block();
EXPECT_EQ(block->size(), 3u);
block->erase(block->back());
EXPECT_EQ(block->size(), 2u);

// Test pir::Region::erase
pir::Region& region = program.module_op()->region(0);
pir::Region &region = program.module_op()->region(0);
region.push_back(new pir::Block());
EXPECT_EQ(region.size(), 2u);
region.erase(region.begin());
EXPECT_EQ(region.size(), 1u);
}

TEST(region, clone_op_test) {
// (1) Init environment.
pir::IrContext *ctx = pir::IrContext::Instance();

// (2) Create an empty program object
pir::Program program(ctx);
pir::Builder builder = pir::Builder(ctx, program.block());

// (3) Def a = ConstantOp("2.0"); b = ConstantOp("2.0");
pir::FloatAttribute fp_attr = builder.float_attr(2.0f);
pir::Float32Type fp32_type = builder.float32_type();
pir::OpResult a =
builder.Build<pir::ConstantOp>(fp_attr, fp32_type)->result(0);
pir::OpResult b =
builder.Build<pir::ConstantOp>(fp_attr, fp32_type)->result(0);

// (6) Def c = CombineOp(a, b)
builder.Build<pir::CombineOp>(std::vector<pir::Value>{a, b});

// (7) Test clone module op
pir::Operation &op = *program.module_op();
pir::Block &block = op.region(0).front();
pir::IrMapping mapper;
pir::Operation &new_op = *op.Clone(mapper, pir::CloneOptions(true, true));

// (8) Check the cloned op recursively
EXPECT_EQ(new_op.num_regions(), 1u);
pir::Region &new_region = new_op.region(0);
EXPECT_EQ(new_region.size(), 1u);
pir::Block &new_block = new_region.front();
EXPECT_EQ(new_block.size(), 3u);

for (auto op_iter = block.begin(), new_op_iter = new_block.begin();
op_iter != block.end();
++op_iter, ++new_op_iter) {
pir::Operation &op = *op_iter;
pir::Operation &new_op = *new_op_iter;
EXPECT_EQ(op.num_operands(), new_op.num_operands());
for (uint32_t i = 0; i < op.num_operands(); ++i) {
EXPECT_EQ(mapper.Lookup(op.operand_source(i)), new_op.operand_source(i));
}
EXPECT_EQ(op.num_results(), new_op.num_results());
for (uint32_t i = 0; i < op.num_results(); ++i) {
EXPECT_EQ(mapper.Lookup(op.result(i)), new_op.result(i));
}
EXPECT_TRUE(std::equal(op.attributes().begin(),
op.attributes().end(),
new_op.attributes().begin(),
new_op.attributes().end()));
EXPECT_EQ(op.info(), new_op.info());
}
}