Skip to content

Commit bcefbaf

Browse files
authored
[PIR][DynamicShape] Polish some codes (#60651)
att, polish some codes
1 parent f177fa6 commit bcefbaf

File tree

11 files changed

+19
-736
lines changed

11 files changed

+19
-736
lines changed

paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,11 @@ void PdOp2CinnOpConverter(::pir::Program *program) {
673673

674674
pm.Run(program);
675675
}
676+
677+
std::unique_ptr<pir::Pass> CreatePdOpToCinnOpPass() {
678+
return std::make_unique<PdOpToCinnOpPass>();
679+
}
680+
676681
} // namespace ir
677682
} // namespace dialect
678683
} // namespace cinn

paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,12 @@ class PdOpToCinnOpPass : public pir::PatternRewritePass {
3131
bool CanApplyOn(pir::Operation *op) const override;
3232
};
3333

34+
// TODO(lanxianghit): delete this and use CreatePdOpToCinnOpPass() in
35+
// corresponding unit tests.
3436
void PdOp2CinnOpConverter(::pir::Program *program);
3537

38+
IR_API std::unique_ptr<pir::Pass> CreatePdOpToCinnOpPass();
39+
3640
} // namespace ir
3741
} // namespace dialect
3842
} // namespace cinn

paddle/fluid/pir/transforms/shape_optimization_pass.cc

Lines changed: 0 additions & 297 deletions
Original file line numberDiff line numberDiff line change
@@ -29,306 +29,9 @@
2929
namespace pir {
3030
namespace {
3131

32-
bool InsertTieShapeOnValue(pir::Value value,
33-
pir::Builder& builder) { // NOLINT
34-
// Insert TieShapeOp only for non-zero ranked tensor type.
35-
auto type = value.type().dyn_cast<DenseTensorType>();
36-
if (!type || type.dims().size() == 0) return true;
37-
38-
std::vector<pir::Value> dim_sizes;
39-
for (int64_t dim = 0, rank = type.dims().size(); dim < rank; ++dim) {
40-
auto dim_op = builder.Build<shape::TensorDimOp>(value, dim);
41-
dim_sizes.push_back(dim_op.out());
42-
}
43-
builder.Build<shape::TieShapeOp>(value, dim_sizes);
44-
return true;
45-
}
46-
47-
// Forward declaration
48-
bool InsertTieShapeOnRegion(pir::Region* region);
49-
50-
bool InsertTieShapeOnOperation(pir::Operation* op,
51-
pir::Builder& builder) { // NOLINT
52-
// TODO(zhangbopd): skip more specialized Ops.
53-
if (op->isa<shape::TieShapeOp>() || op->isa<shape::FuncOp>()) return true;
54-
55-
for (size_t i = 0; i < op->num_regions(); ++i) {
56-
if (!InsertTieShapeOnRegion(&(op->region(i)))) return false;
57-
}
58-
builder.SetInsertionPointAfter(op);
59-
for (pir::OpResult v : op->results()) {
60-
if (!InsertTieShapeOnValue(v, builder)) return false;
61-
}
62-
63-
return true;
64-
}
65-
66-
bool InsertTieShapeOnBlock(pir::Block* block) {
67-
pir::Builder builder =
68-
pir::Builder(pir::IrContext::Instance(), block, block->begin());
69-
// TODO(zhangbopd): mapping block arguments
70-
71-
std::vector<pir::Operation*> op_list;
72-
for (auto& op : *block) op_list.push_back(&op);
73-
for (pir::Operation* op : op_list) {
74-
if (!InsertTieShapeOnOperation(op, builder)) return false;
75-
}
76-
return true;
77-
}
78-
79-
bool InsertTieShapeOnRegion(pir::Region* region) {
80-
for (auto& block : *region) {
81-
if (!InsertTieShapeOnBlock(&block)) return false;
82-
}
83-
return true;
84-
}
85-
86-
// Convert:
87-
// %shape = shape.shape_of %0 : tensor<?x?xf32> -> tensor<2xindex>
88-
// To:
89-
// %d0 = tensor.dim %0, %c0 : tensor<?x?xf32>
90-
// %d1 = tensor.dim %0, %c1 : tensor<?x?xf32>
91-
// %shape = tensor.from_elements %d0, %d1 : tensor<2xindex>
92-
struct ExpandShapeOfOpPattern : public OpRewritePattern<shape::ShapeOfOp> {
93-
using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
94-
95-
bool MatchAndRewrite(shape::ShapeOfOp op,
96-
PatternRewriter& rewriter) const override {
97-
VLOG(3) << "Apply ExpandShapeOfOpPattern...";
98-
99-
auto type = op.out().type().dyn_cast<pir::DenseTensorType>();
100-
101-
if (!type || !type.dyn_cast<ShapedTypeInterface>().HasStaticShape() ||
102-
!type.dyn_cast<ShapedTypeInterface>().GetElementType().IsIndex())
103-
return false;
104-
105-
std::vector<Value> dim_sizes;
106-
for (int dim = 0,
107-
rank = type.dyn_cast<ShapedTypeInterface>().GetDyShape()[0];
108-
dim < rank;
109-
++dim) {
110-
dim_sizes.push_back(
111-
rewriter.Build<shape::TensorDimOp>(op.input(), dim).out());
112-
}
113-
rewriter.ReplaceOpWithNewOp<shape::FromElementsOp>(op, dim_sizes);
114-
return true;
115-
}
116-
};
117-
118-
// Fold dim of an operation that implements the InferSymbolicShapeInterface
119-
template <typename OpTy>
120-
struct DimOfShapedTypeOpInterfacePattern : public OpRewritePattern<OpTy> {
121-
using OpRewritePattern<OpTy>::OpRewritePattern;
122-
123-
bool MatchAndRewrite(OpTy dim_op, PatternRewriter& rewriter) const override {
124-
return true;
125-
}
126-
};
127-
12832
using PassPipelineRunner =
12933
std::function<bool(pir::PassManager&, pir::ModuleOp)>;
13034

131-
// Returns true if the type is possible to be a shape tensor type.
132-
// Shape tensor type :
133-
// - rank-1 static-shaped tensor type
134-
// - element type of the tensor is int or index
135-
// - number of elements of the tensor < 32, supposing that the
136-
// higiest possible rank is smaller than 32.
137-
bool IsCandidateShapeTensorType(Type type) {
138-
auto tensor_type = type.dyn_cast<DenseTensorType>();
139-
auto shaped_type = tensor_type.dyn_cast<ShapedTypeInterface>();
140-
141-
return (tensor_type && tensor_type && shaped_type.GetRank() == 1 &&
142-
shaped_type.HasStaticShape() &&
143-
shaped_type.GetElementType().IsIntOrIndex() &&
144-
shaped_type.GetDyShape()[0] < 32);
145-
}
146-
147-
class ShapeComputationIRAnalysis {
148-
public:
149-
using func = std::function<bool(Operation* op)>;
150-
explicit ShapeComputationIRAnalysis(ModuleOp m,
151-
SymbolicDimMgr& mgr); // NOLINT
152-
bool Run();
153-
154-
private:
155-
bool RunOnRegion(Region* region, func fn);
156-
bool RunOnBlock(Block* block, func fn);
157-
bool RunOnOperation(Operation* op, func fn);
158-
159-
bool BuildShapeOnOperation(Operation* op);
160-
bool BuildShapeOnValue(Value value);
161-
162-
bool ApplyOpConstraint(Operation* op);
163-
bool ApplyIndexOpConstraint(Operation* op);
164-
bool ApplyTieShapeOpConstraint(Operation* op);
165-
166-
bool initialized_ = false;
167-
ModuleOp m_;
168-
SymbolicDimMgr& mgr_;
169-
170-
std::unordered_map<Value, SymbolicDimOp> value_to_sym_dim_;
171-
172-
// shape tensor is the 1D ranked tensor with int/index dtype.
173-
std::unordered_map<Value, std::vector<SymbolicDimOp>>
174-
shape_tensor_to_sym_dims_;
175-
176-
std::unordered_map<Value, std::vector<SymbolicDimOp>>
177-
dense_tensor_to_sym_dims_;
178-
};
179-
180-
ShapeComputationIRAnalysis::ShapeComputationIRAnalysis(ModuleOp m,
181-
SymbolicDimMgr& mgr)
182-
: m_(m), mgr_(mgr) {}
183-
184-
bool ShapeComputationIRAnalysis::Run() {
185-
// Make sure only run once.
186-
if (initialized_) return false;
187-
initialized_ = true;
188-
return true;
189-
}
190-
191-
bool ShapeComputationIRAnalysis::RunOnRegion(Region* region, func fn) {
192-
for (auto& block : *region) {
193-
if (!RunOnBlock(&block, fn)) return false;
194-
}
195-
return true;
196-
}
197-
198-
bool ShapeComputationIRAnalysis::RunOnBlock(Block* block, func fn) {
199-
// TODO(zhangbopd): mapping block arguments
200-
201-
std::vector<Operation*> op_list;
202-
for (auto& op : *block) op_list.push_back(&op);
203-
for (Operation* op : op_list) {
204-
if (!RunOnOperation(op, fn)) return false;
205-
}
206-
return true;
207-
}
208-
209-
bool ShapeComputationIRAnalysis::RunOnOperation(Operation* op, func fn) {
210-
for (size_t i = 0; i < op->num_regions(); ++i) {
211-
if (!RunOnRegion(&(op->region(i)), fn)) return false;
212-
}
213-
return fn(op);
214-
}
215-
216-
bool ShapeComputationIRAnalysis::BuildShapeOnOperation(Operation* op) {
217-
if (op->isa<shape::FuncOp>()) return true;
218-
if (op->isa<shape::TieShapeOp>()) {
219-
Value value = op->operand_source(0);
220-
std::vector<SymbolicDimOp> symbols;
221-
if (op->HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) {
222-
auto attrs =
223-
op->attribute<ArrayAttribute>(SymbolicDimOp::GetSymbolicDimAttrName())
224-
.AsVector();
225-
for (Attribute attr : attrs) {
226-
auto sym = mgr_.symbolTable().Lookup<SymbolicDimOp>(
227-
attr.dyn_cast<StrAttribute>().AsString());
228-
IR_ENFORCE(sym);
229-
SymbolicDimOp root = mgr_.GetRootSymbolicDim(sym);
230-
symbols.push_back(root);
231-
}
232-
} else {
233-
symbols = mgr_.CreateSymbolicDimsForRankedValue(value);
234-
std::vector<Attribute> attrs;
235-
for (SymbolicDimOp sym : symbols) {
236-
Attribute rootSymbol =
237-
StrAttribute::get(m_->ir_context(), sym.GetSymName());
238-
attrs.push_back(rootSymbol);
239-
}
240-
op->set_attribute(SymbolicDimOp::GetSymbolicDimAttrName(),
241-
ArrayAttribute::get(m_->ir_context(), attrs));
242-
}
243-
dense_tensor_to_sym_dims_[value] = std::move(symbols);
244-
return true;
245-
}
246-
for (auto& result : op->results()) {
247-
if (!BuildShapeOnValue(result)) return false;
248-
}
249-
return true;
250-
}
251-
252-
bool ShapeComputationIRAnalysis::BuildShapeOnValue(Value value) {
253-
Type type = value.type();
254-
if (type.IsIntOrIndex()) {
255-
SymbolicDimOp sym = mgr_.NewSymbolicDim();
256-
value_to_sym_dim_[value] = sym;
257-
} else if (IsCandidateShapeTensorType(type)) {
258-
auto shaped_type = type.dyn_cast<ShapedTypeInterface>();
259-
std::vector<SymbolicDimOp> symbols;
260-
for (size_t i = 0, d = shaped_type.GetDyShape()[0]; i < d; ++i)
261-
symbols.push_back(mgr_.NewSymbolicDim());
262-
shape_tensor_to_sym_dims_[value] = std::move(symbols);
263-
}
264-
return true;
265-
}
266-
267-
bool ShapeComputationIRAnalysis::ApplyOpConstraint(Operation* op) {
268-
IR_ENFORCE(ApplyIndexOpConstraint(op),
269-
"Fail to apply constraint for index op");
270-
IR_ENFORCE(ApplyTieShapeOpConstraint(op),
271-
"Fail to apply constraint for tie_shape op");
272-
273-
// TODO(zhangbopd): add more constraints
274-
return true;
275-
}
276-
277-
bool ShapeComputationIRAnalysis::ApplyIndexOpConstraint(Operation* op) {
278-
if (op->num_results() == 0) return true;
279-
280-
Type type = op->result(0).type();
281-
if (!type.IsIntOrIndex()) return true;
282-
283-
if (auto dim_op = op->dyn_cast<shape::TensorDimOp>()) {
284-
int64_t dim_index = dim_op.index()
285-
.dyn_cast<OpResult>()
286-
.owner()
287-
->attribute<Int64Attribute>("value")
288-
.data();
289-
value_to_sym_dim_[dim_op.out()].UpdateKnownNonNegative(true);
290-
if (!mgr_.MapSymbolicDimEqual(
291-
value_to_sym_dim_[dim_op.out()],
292-
dense_tensor_to_sym_dims_[dim_op.source()][dim_index])) {
293-
return false;
294-
}
295-
296-
} else if (auto const_op = op->dyn_cast<ConstantOp>()) {
297-
int64_t val = const_op.value().dyn_cast<Int64Attribute>().data();
298-
if (!mgr_.MapSymbolicDimEqual(value_to_sym_dim_[op->result(0)],
299-
mgr_.NewConstantSymbolicDim(val))) {
300-
return false;
301-
}
302-
}
303-
// TODO(zhangbopd): add support for reifyInferShape. (e.g. mul/add)
304-
return true;
305-
}
306-
307-
bool ShapeComputationIRAnalysis::ApplyTieShapeOpConstraint(Operation* op) {
308-
if (auto tie_shape = op->dyn_cast<shape::TieShapeOp>()) {
309-
auto& value = dense_tensor_to_sym_dims_[op->operand_source(0)];
310-
for (size_t idx = 0; idx < tie_shape.dims().size(); ++idx) {
311-
if (!mgr_.MapSymbolicDimEqual(value_to_sym_dim_[tie_shape.dims()[idx]],
312-
value[idx]))
313-
return false;
314-
mgr_.GetRootSymbolicDim(value[idx]).UpdateKnownNonNegative(true);
315-
}
316-
}
317-
return true;
318-
}
319-
320-
bool OptimizeShapeComputation(pir::ModuleOp m, PassPipelineRunner runner) {
321-
// TODO(zhangbopd): Do some Canonicalizer.
322-
pir::SymbolicDimMgr mgr(m);
323-
324-
ShapeComputationIRAnalysis analysis(m, mgr);
325-
if (!analysis.Run()) {
326-
return false;
327-
}
328-
329-
return true;
330-
}
331-
33235
void PrintProgram(pir::ModuleOp m, std::string mgs) {
33336
std::ostringstream print_stream;
33437
print_stream << "\n\n";

paddle/fluid/pybind/pir.cc

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,8 +1564,7 @@ void AddCinnPass(std::shared_ptr<PassManager> &pass_manager, // NOLINT
15641564
has_dynamic_shape ? std::make_shared<pir::ShapeConstraintIRAnalysis>(ctx)
15651565
: nullptr;
15661566

1567-
cinn::dialect::ir::PdOp2CinnOpConverter(&program);
1568-
1567+
pass_manager->AddPass(cinn::dialect::ir::CreatePdOpToCinnOpPass());
15691568
pass_manager->AddPass(
15701569
std::make_unique<cinn::dialect::ir::AddBroadcastToElementwisePass>());
15711570
pass_manager->AddPass(pir::CreateDeadCodeEliminationPass());
@@ -1583,8 +1582,7 @@ void AddCinnPass(std::shared_ptr<PassManager> &pass_manager, // NOLINT
15831582
}
15841583

15851584
void InferSymbolicShapePass(
1586-
std::shared_ptr<PassManager> &pass_manager, // NOLINT
1587-
Program &program) { // NOLINT
1585+
std::shared_ptr<PassManager> &pass_manager) { // NOLINT
15881586
if (FLAGS_pir_apply_shape_optimization_pass) {
15891587
pir::IrContext *ctx = pir::IrContext::Instance();
15901588
ctx->GetOrRegisterDialect<pir::shape::ShapeDialect>();

paddle/pir/dialect/shape/utils/shape_optimization_utils.cc

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,7 @@ bool CompareSymbolicDimProduct(SymbolicDimProduct& lhs, // NOLINT
4848
return false;
4949
}
5050

51-
SymbolicDimMgr::SymbolicDimMgr(ModuleOp m) : m_(m) {
52-
for (auto& op : m.block()) {
53-
if (op.isa<shape::FuncOp>()) {
54-
symbol_table_ = SymbolTable(&op);
55-
return;
56-
}
57-
}
58-
Builder builder = Builder(m_.ir_context(), &m_.block(), m_.block().begin());
59-
shape::FuncOp func = builder.Build<shape::FuncOp>();
60-
symbol_table_ = SymbolTable(func);
61-
}
51+
SymbolicDimMgr::SymbolicDimMgr(ModuleOp m) : m_(m) {}
6252

6353
bool SymbolicDimMgr::MapSymbolicDimProductEqual(const SymbolicDimProduct& lhs,
6454
const SymbolicDimProduct& rhs) {
@@ -176,9 +166,7 @@ const std::string SymbolicDimMgr::GetNextName() {
176166
}
177167

178168
SymbolicDimOp SymbolicDimMgr::NewSymbolicDim(const std::string& name) {
179-
auto func_op = symbol_table_.getOp()->dyn_cast<shape::FuncOp>();
180-
IR_ENFORCE(func_op);
181-
Builder builder = Builder(m_.ir_context(), func_op.block());
169+
Builder builder = Builder(m_.ir_context(), nullptr, Block::Iterator{}, false);
182170
// default settting dim != 0
183171
SymbolicDimOp symbol =
184172
builder.Build<SymbolicDimOp>(name.empty() ? GetNextName() : name,

0 commit comments

Comments
 (0)