Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] Support Region Clone in Operation::Clone #60590

Merged
merged 10 commits into from
Jan 9, 2024
89 changes: 24 additions & 65 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -898,47 +898,6 @@ void BindInsertionPoint(pybind11::module *m) {
InsertionPoint class represents the insertion point in the Builder.)DOC");
}

Operation *BuildOpFrom(
Operation *to_copy_op,
std::unordered_map<pir::Value, pir::Value> &value_map) { // NOLINT
pir::OperationArgument to_create_argument(to_copy_op->info());
to_create_argument.attributes = to_copy_op->attributes();

VLOG(6) << "start copy op: " << to_copy_op->name();
auto origin_results = to_copy_op->results();
VLOG(6) << "start translate origin results into op type.";
std::transform(origin_results.begin(),
origin_results.end(),
std::back_inserter(to_create_argument.output_types),
[](const pir::OpResult &r) {
// OpResult -> OpType
return r.type();
});

// transform by value_map dict.
VLOG(6) << "start create op.";
auto origin_operands = to_copy_op->operands();
std::transform(origin_operands.begin(),
origin_operands.end(),
std::back_inserter(to_create_argument.inputs),
[&value_map](const pir::OpOperand &operand) {
// Operand -> OpResult
return value_map[operand.source()];
});
auto *cloned_op = Operation::Create(std::move(to_create_argument));

std::vector<int> tmp;
std::transform(origin_results.begin(),
origin_results.end(),
cloned_op->results().begin(),
std::back_inserter(tmp), // NOLINT, just a placeholder.
[&value_map](const OpResult &a, const OpResult &b) { // NOLINT
value_map[a.Value::impl()] = b.Value::impl();
return 1;
});
return cloned_op;
}

std::list<Operation *>::const_iterator list_offset(const Block *block,
int start_idx) {
auto it = block->begin();
Expand Down Expand Up @@ -1057,19 +1016,13 @@ static auto GetNoNeedBufferValue(const ::pir::Block *whole_block,
using OpResultMap =
std::pair<std::vector<pir::OpResult>, std::vector<pir::OpResult>>;
std::pair<std::shared_ptr<Program>, OpResultMap> CloneProgram(
const Program &program) {
Program &program) { // NOLINT
// Limitation of this function:
// 1. don't support Parameters.
// 2. don't support Regions in operator.
pir::IrContext *ctx = pir::IrContext::Instance();
auto cloned_program = std::make_shared<Program>(ctx);
std::unordered_map<pir::Value, pir::Value> value_map;
for (auto &op : *program.block()) {
auto *cloned_op = BuildOpFrom(&op, value_map);
cloned_program->block()->push_back(cloned_op);
}
pir::IrMapping mapper;
auto cloned_program = program.Clone(mapper);
std::vector<pir::OpResult> associated_array_key, associated_array_value;
for (auto &pair : value_map) {
for (auto &pair : mapper.Map<pir::Value>()) {
associated_array_key.push_back(pair.first.dyn_cast<pir::OpResult>());
associated_array_value.push_back(pair.second.dyn_cast<pir::OpResult>());
}
Expand Down Expand Up @@ -1178,21 +1131,26 @@ SplitedResult SplitForwardBackward(
std::unordered_set<pir::Value> backward_inputs;
std::tie(middle_values, backward_inputs) = AnalysisMiddleVariable(
program, forward_in_out_values, forward_range, backward_range);
std::unordered_map<pir::Value, pir::Value> forward_value_map;
std::unordered_map<pir::Value, pir::Value> backward_value_map;
pir::Builder backward_builder = pir::Builder(ctx, backward_program->block());
bool has_backward = (backward_range[1] > backward_range[0]);

// forward program construct.
VLOG(4) << "start create forward program.";
range_block_do(program.block(),
forward_range,
[&forward_value_map, &forward_program](Operation *op) {
auto *cloned_op = BuildOpFrom(op, forward_value_map);
forward_program->block()->push_back(cloned_op);
});
pir::IrMapping forward_mapper;
auto clone_options = pir::CloneOptions(true, true);
range_block_do(
program.block(),
forward_range,
[&forward_mapper, &forward_program, &clone_options](Operation *op) {
auto *cloned_op = op->Clone(forward_mapper, clone_options);
forward_program->block()->push_back(cloned_op);
});
auto &forward_value_map = forward_mapper.MutableMap<pir::Value>();

// backward program construc.
// Step1. insert data op for inputs_values and middle_values
pir::IrMapping backward_mapper;
auto &backward_value_map = backward_mapper.MutableMap<pir::Value>();
int counter = 0;
auto create_data_fn = [&backward_builder,
&backward_inputs,
Expand Down Expand Up @@ -1311,12 +1269,13 @@ SplitedResult SplitForwardBackward(

// Step2. copy backward ops .
VLOG(4) << "start copy backward ops";
range_block_do(program.block(),
backward_range,
[&backward_value_map, &backward_program](Operation *op) {
auto *cloned_op = BuildOpFrom(op, backward_value_map);
backward_program->block()->push_back(cloned_op);
});
range_block_do(
program.block(),
backward_range,
[&backward_mapper, &backward_program, &clone_options](Operation *op) {
auto *cloned_op = op->Clone(backward_mapper, clone_options);
backward_program->block()->push_back(cloned_op);
});
// counter = 0;
VLOG(4) << "start create backward outputs, inserting set_parameter ops.";
if (has_backward) {
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/block.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ class IR_API Block {
friend class Region;
void SetParent(Region *parent);

// Take out corresponding Operation and its ownershipe.
// Take out corresponding Operation and its ownership.
friend class Operation;
Operation *Take(Operation *op);

Expand Down
61 changes: 53 additions & 8 deletions paddle/pir/core/ir_mapping.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,69 @@
#pragma once
#include <unordered_map>
#include "paddle/common/enforce.h"
#include "paddle/pir/core/block.h"
#include "paddle/pir/core/value.h"

namespace pir {
class Block;
class Operation;

class IrMapping {
public:
void Add(Value from, Value to) { value_map_[from] = to; }
template <typename T>
void Add(T from, T to) {
if (!from) return;
MutableMap<T>()[from] = to;
}

template <typename T>
T Lookup(T from) const {
if (!from) return static_cast<T>(nullptr);
IR_ENFORCE(Map<T>().count(from) > 0, "Not found key in IRMapping.");
return Map<T>().at(from);
}

template <typename T>
void Earse(T from) {
MutableMap<T>().erase(from);
}

Value Lookup(Value from) const {
IR_ENFORCE(value_map_.count(from) > 0, "Not Found Value in IRMapping.");
return value_map_.at(from);
void Clear() {
value_map_.clear();
block_map_.clear();
operation_map_.clear();
}
void Earse(Value from) { value_map_.erase(from); }

void Clear() { value_map_.clear(); }
template <typename T>
using MapType = std::unordered_map<T, T>;

template <typename T>
const MapType<T> &Map() const {
if constexpr (std::is_convertible<T, Value>::value)
return value_map_;
else if constexpr (std::is_convertible<T, Block *>::value)
return block_map_;
else if constexpr (std::is_convertible<T, Operation *>::value)
return operation_map_;
else
IR_THROW("Not support type in IRMapping.");
}

template <typename T>
MapType<T> &MutableMap() {
if constexpr (std::is_convertible<T, Value>::value)
return value_map_;
else if constexpr (std::is_convertible<T, Block *>::value)
return block_map_;
else if constexpr (std::is_convertible<T, Operation *>::value)
return operation_map_;
else
IR_THROW("Not support type in IRMapping.");
}

private:
std::unordered_map<Value, Value> value_map_;
MapType<Value> value_map_;
MapType<Block *> block_map_;
MapType<Operation *> operation_map_;
};

} // namespace pir
15 changes: 12 additions & 3 deletions paddle/pir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,6 @@ Operation *Operation::Create(const std::vector<Value> &inputs,
}

Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) {
huangjiyi marked this conversation as resolved.
Show resolved Hide resolved
IR_ENFORCE(!options.IsCloneRegions() || num_regions_ <= 0,
"Operation CloneRegions is unimplemented currently.");
IR_ENFORCE(num_successors_ == 0,
"Operation::Clone is not unimplemented for multiple successors.");

Expand All @@ -156,10 +154,21 @@ Operation *Operation::Clone(IrMapping &ir_mapping, CloneOptions options) {
output_types.push_back(result.type());
}
auto *new_op = Create(inputs, attributes_, output_types, info_, num_regions_);
ir_mapping.Add(this, new_op);

// record outputs mapping info
for (uint32_t i = 0; i < num_results_; ++i) {
ir_mapping.Add(result(i), new_op->result(i));
ir_mapping.Add(static_cast<Value>(result(i)),
static_cast<Value>(new_op->result(i)));
}

if (options.IsCloneRegions()) {
// clone regions recursively
for (uint32_t i = 0; i < num_regions_; ++i) {
this->region(i).CloneInto(new_op->region(i), ir_mapping);
}
}

return new_op;
}

Expand Down
11 changes: 11 additions & 0 deletions paddle/pir/core/program.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ Program::~Program() {
}
}

std::shared_ptr<Program> Program::Clone(IrMapping& ir_mapping) {
pir::IrContext* ctx = pir::IrContext::Instance();
auto new_program = std::make_shared<Program>(ctx);
auto clone_options = CloneOptions(true, true);
for (auto& op : *block()) {
auto* new_op = op.Clone(ir_mapping, clone_options);
new_program->block()->push_back(new_op);
}
return new_program;
}

Parameter* Program::GetParameter(const std::string& name) const {
if (parameters_.count(name) != 0) {
return parameters_.at(name).get();
Expand Down
3 changes: 3 additions & 0 deletions paddle/pir/core/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "paddle/pir/core/block.h"
#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/ir_mapping.h"
#include "paddle/pir/core/operation.h"
#include "paddle/pir/core/parameter.h"

Expand Down Expand Up @@ -54,6 +55,8 @@ class IR_API Program {

static std::unique_ptr<Program> Parse(std::istream& is, IrContext* ctx);

std::shared_ptr<Program> Clone(IrMapping& ir_mapping); // NOLINT

Block* block() { return &module_.block(); }
const Block* block() const { return &module_op().block(); }

Expand Down
40 changes: 40 additions & 0 deletions paddle/pir/core/region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,46 @@ Region::Iterator Region::erase(ConstIterator position) {
return blocks_.erase(position);
}

void Region::CloneInto(Region &other, IrMapping &ir_mapping) {
if (empty()) {
return;
}
other.clear();
auto clone_options = CloneOptions(false, false);
// clone blocks, block arguments and sub operations
for (auto &block : *this) {
auto new_block = new Block;
ir_mapping.Add(&block, new_block);
for (auto &arg : block.args()) {
ir_mapping.Add(arg, new_block->AddArgument(arg.type()));
}
other.push_back(new_block);
// clone sub operations, but not map operands nor clone regions
for (auto op_iter = block.begin(); op_iter != block.end(); ++op_iter) {
new_block->push_back(op_iter->Clone(ir_mapping, clone_options));
}
}
// after all operation results are mapped, map operands and clone regions.
{
auto iter = begin();
auto new_iter = other.begin();
for (; iter != end(); ++iter, ++new_iter) {
auto op_iter = iter->begin();
auto new_op_iter = new_iter->begin();
for (; op_iter != iter->end(); ++op_iter, ++new_op_iter) {
Operation &op = *op_iter;
Operation &new_op = *new_op_iter;
// operands of new_op are same as op, now map them.
for (uint32_t i = 0; i < op.num_operands(); ++i)
new_op.operand(i).set_source(ir_mapping.Lookup(op.operand_source(i)));
// clone sub regions
for (uint32_t i = 0; i < op.num_regions(); ++i)
op.region(i).CloneInto(new_op.region(i), ir_mapping);
}
}
}
}

std::unique_ptr<pir::Block> Region::TakeBack() {
Block *block = nullptr;
if (!blocks_.empty()) {
Expand Down
4 changes: 4 additions & 0 deletions paddle/pir/core/region.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <memory>

#include "paddle/pir/core/dll_decl.h"
#include "paddle/pir/core/ir_mapping.h"
#include "paddle/pir/core/iterator.h"
#include "paddle/pir/core/visitors.h"

Expand Down Expand Up @@ -71,6 +72,9 @@ class IR_API Region {
template <WalkOrder Order = WalkOrder::PostOrder, typename FuncT>
void Walk(FuncT &&callback);

// clone this region into another region, target region will be overwritten.
void CloneInto(Region &other, IrMapping &ir_mapping); // NOLINT

// take the last block of region.
// if region is empty, return nullptr;
std::unique_ptr<Block> TakeBack();
Expand Down
Loading