|
29 | 29 | namespace pir {
|
30 | 30 | namespace {
|
31 | 31 |
|
32 |
| -bool InsertTieShapeOnValue(pir::Value value, |
33 |
| - pir::Builder& builder) { // NOLINT |
34 |
| - // Insert TieShapeOp only for non-zero ranked tensor type. |
35 |
| - auto type = value.type().dyn_cast<DenseTensorType>(); |
36 |
| - if (!type || type.dims().size() == 0) return true; |
37 |
| - |
38 |
| - std::vector<pir::Value> dim_sizes; |
39 |
| - for (int64_t dim = 0, rank = type.dims().size(); dim < rank; ++dim) { |
40 |
| - auto dim_op = builder.Build<shape::TensorDimOp>(value, dim); |
41 |
| - dim_sizes.push_back(dim_op.out()); |
42 |
| - } |
43 |
| - builder.Build<shape::TieShapeOp>(value, dim_sizes); |
44 |
| - return true; |
45 |
| -} |
46 |
| - |
47 |
| -// Forward declaration |
48 |
| -bool InsertTieShapeOnRegion(pir::Region* region); |
49 |
| - |
50 |
| -bool InsertTieShapeOnOperation(pir::Operation* op, |
51 |
| - pir::Builder& builder) { // NOLINT |
52 |
| - // TODO(zhangbopd): skip more specialized Ops. |
53 |
| - if (op->isa<shape::TieShapeOp>() || op->isa<shape::FuncOp>()) return true; |
54 |
| - |
55 |
| - for (size_t i = 0; i < op->num_regions(); ++i) { |
56 |
| - if (!InsertTieShapeOnRegion(&(op->region(i)))) return false; |
57 |
| - } |
58 |
| - builder.SetInsertionPointAfter(op); |
59 |
| - for (pir::OpResult v : op->results()) { |
60 |
| - if (!InsertTieShapeOnValue(v, builder)) return false; |
61 |
| - } |
62 |
| - |
63 |
| - return true; |
64 |
| -} |
65 |
| - |
66 |
| -bool InsertTieShapeOnBlock(pir::Block* block) { |
67 |
| - pir::Builder builder = |
68 |
| - pir::Builder(pir::IrContext::Instance(), block, block->begin()); |
69 |
| - // TODO(zhangbopd): mapping block arguments |
70 |
| - |
71 |
| - std::vector<pir::Operation*> op_list; |
72 |
| - for (auto& op : *block) op_list.push_back(&op); |
73 |
| - for (pir::Operation* op : op_list) { |
74 |
| - if (!InsertTieShapeOnOperation(op, builder)) return false; |
75 |
| - } |
76 |
| - return true; |
77 |
| -} |
78 |
| - |
79 |
| -bool InsertTieShapeOnRegion(pir::Region* region) { |
80 |
| - for (auto& block : *region) { |
81 |
| - if (!InsertTieShapeOnBlock(&block)) return false; |
82 |
| - } |
83 |
| - return true; |
84 |
| -} |
85 |
| - |
86 |
| -// Convert: |
87 |
| -// %shape = shape.shape_of %0 : tensor<?x?xf32> -> tensor<2xindex> |
88 |
| -// To: |
89 |
| -// %d0 = tensor.dim %0, %c0 : tensor<?x?xf32> |
90 |
| -// %d1 = tensor.dim %0, %c1 : tensor<?x?xf32> |
91 |
| -// %shape = tensor.from_elements %d0, %d1 : tensor<2xindex> |
92 |
| -struct ExpandShapeOfOpPattern : public OpRewritePattern<shape::ShapeOfOp> { |
93 |
| - using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern; |
94 |
| - |
95 |
| - bool MatchAndRewrite(shape::ShapeOfOp op, |
96 |
| - PatternRewriter& rewriter) const override { |
97 |
| - VLOG(3) << "Apply ExpandShapeOfOpPattern..."; |
98 |
| - |
99 |
| - auto type = op.out().type().dyn_cast<pir::DenseTensorType>(); |
100 |
| - |
101 |
| - if (!type || !type.dyn_cast<ShapedTypeInterface>().HasStaticShape() || |
102 |
| - !type.dyn_cast<ShapedTypeInterface>().GetElementType().IsIndex()) |
103 |
| - return false; |
104 |
| - |
105 |
| - std::vector<Value> dim_sizes; |
106 |
| - for (int dim = 0, |
107 |
| - rank = type.dyn_cast<ShapedTypeInterface>().GetDyShape()[0]; |
108 |
| - dim < rank; |
109 |
| - ++dim) { |
110 |
| - dim_sizes.push_back( |
111 |
| - rewriter.Build<shape::TensorDimOp>(op.input(), dim).out()); |
112 |
| - } |
113 |
| - rewriter.ReplaceOpWithNewOp<shape::FromElementsOp>(op, dim_sizes); |
114 |
| - return true; |
115 |
| - } |
116 |
| -}; |
117 |
| - |
118 |
| -// Fold dim of an operation that implements the InferSymbolicShapeInterface |
119 |
| -template <typename OpTy> |
120 |
| -struct DimOfShapedTypeOpInterfacePattern : public OpRewritePattern<OpTy> { |
121 |
| - using OpRewritePattern<OpTy>::OpRewritePattern; |
122 |
| - |
123 |
| - bool MatchAndRewrite(OpTy dim_op, PatternRewriter& rewriter) const override { |
124 |
| - return true; |
125 |
| - } |
126 |
| -}; |
127 |
| - |
128 | 32 | using PassPipelineRunner =
|
129 | 33 | std::function<bool(pir::PassManager&, pir::ModuleOp)>;
|
130 | 34 |
|
131 |
| -// Returns true if the type is possible to be a shape tensor type. |
132 |
| -// Shape tensor type : |
133 |
| -// - rank-1 static-shaped tensor type |
134 |
| -// - element type of the tensor is int or index |
135 |
| -// - number of elements of the tensor < 32, supposing that the |
136 |
| -// higiest possible rank is smaller than 32. |
137 |
| -bool IsCandidateShapeTensorType(Type type) { |
138 |
| - auto tensor_type = type.dyn_cast<DenseTensorType>(); |
139 |
| - auto shaped_type = tensor_type.dyn_cast<ShapedTypeInterface>(); |
140 |
| - |
141 |
| - return (tensor_type && tensor_type && shaped_type.GetRank() == 1 && |
142 |
| - shaped_type.HasStaticShape() && |
143 |
| - shaped_type.GetElementType().IsIntOrIndex() && |
144 |
| - shaped_type.GetDyShape()[0] < 32); |
145 |
| -} |
146 |
| - |
147 |
| -class ShapeComputationIRAnalysis { |
148 |
| - public: |
149 |
| - using func = std::function<bool(Operation* op)>; |
150 |
| - explicit ShapeComputationIRAnalysis(ModuleOp m, |
151 |
| - SymbolicDimMgr& mgr); // NOLINT |
152 |
| - bool Run(); |
153 |
| - |
154 |
| - private: |
155 |
| - bool RunOnRegion(Region* region, func fn); |
156 |
| - bool RunOnBlock(Block* block, func fn); |
157 |
| - bool RunOnOperation(Operation* op, func fn); |
158 |
| - |
159 |
| - bool BuildShapeOnOperation(Operation* op); |
160 |
| - bool BuildShapeOnValue(Value value); |
161 |
| - |
162 |
| - bool ApplyOpConstraint(Operation* op); |
163 |
| - bool ApplyIndexOpConstraint(Operation* op); |
164 |
| - bool ApplyTieShapeOpConstraint(Operation* op); |
165 |
| - |
166 |
| - bool initialized_ = false; |
167 |
| - ModuleOp m_; |
168 |
| - SymbolicDimMgr& mgr_; |
169 |
| - |
170 |
| - std::unordered_map<Value, SymbolicDimOp> value_to_sym_dim_; |
171 |
| - |
172 |
| - // shape tensor is the 1D ranked tensor with int/index dtype. |
173 |
| - std::unordered_map<Value, std::vector<SymbolicDimOp>> |
174 |
| - shape_tensor_to_sym_dims_; |
175 |
| - |
176 |
| - std::unordered_map<Value, std::vector<SymbolicDimOp>> |
177 |
| - dense_tensor_to_sym_dims_; |
178 |
| -}; |
179 |
| - |
180 |
| -ShapeComputationIRAnalysis::ShapeComputationIRAnalysis(ModuleOp m, |
181 |
| - SymbolicDimMgr& mgr) |
182 |
| - : m_(m), mgr_(mgr) {} |
183 |
| - |
184 |
| -bool ShapeComputationIRAnalysis::Run() { |
185 |
| - // Make sure only run once. |
186 |
| - if (initialized_) return false; |
187 |
| - initialized_ = true; |
188 |
| - return true; |
189 |
| -} |
190 |
| - |
191 |
| -bool ShapeComputationIRAnalysis::RunOnRegion(Region* region, func fn) { |
192 |
| - for (auto& block : *region) { |
193 |
| - if (!RunOnBlock(&block, fn)) return false; |
194 |
| - } |
195 |
| - return true; |
196 |
| -} |
197 |
| - |
198 |
| -bool ShapeComputationIRAnalysis::RunOnBlock(Block* block, func fn) { |
199 |
| - // TODO(zhangbopd): mapping block arguments |
200 |
| - |
201 |
| - std::vector<Operation*> op_list; |
202 |
| - for (auto& op : *block) op_list.push_back(&op); |
203 |
| - for (Operation* op : op_list) { |
204 |
| - if (!RunOnOperation(op, fn)) return false; |
205 |
| - } |
206 |
| - return true; |
207 |
| -} |
208 |
| - |
209 |
| -bool ShapeComputationIRAnalysis::RunOnOperation(Operation* op, func fn) { |
210 |
| - for (size_t i = 0; i < op->num_regions(); ++i) { |
211 |
| - if (!RunOnRegion(&(op->region(i)), fn)) return false; |
212 |
| - } |
213 |
| - return fn(op); |
214 |
| -} |
215 |
| - |
216 |
| -bool ShapeComputationIRAnalysis::BuildShapeOnOperation(Operation* op) { |
217 |
| - if (op->isa<shape::FuncOp>()) return true; |
218 |
| - if (op->isa<shape::TieShapeOp>()) { |
219 |
| - Value value = op->operand_source(0); |
220 |
| - std::vector<SymbolicDimOp> symbols; |
221 |
| - if (op->HasAttribute(SymbolicDimOp::GetSymbolicDimAttrName())) { |
222 |
| - auto attrs = |
223 |
| - op->attribute<ArrayAttribute>(SymbolicDimOp::GetSymbolicDimAttrName()) |
224 |
| - .AsVector(); |
225 |
| - for (Attribute attr : attrs) { |
226 |
| - auto sym = mgr_.symbolTable().Lookup<SymbolicDimOp>( |
227 |
| - attr.dyn_cast<StrAttribute>().AsString()); |
228 |
| - IR_ENFORCE(sym); |
229 |
| - SymbolicDimOp root = mgr_.GetRootSymbolicDim(sym); |
230 |
| - symbols.push_back(root); |
231 |
| - } |
232 |
| - } else { |
233 |
| - symbols = mgr_.CreateSymbolicDimsForRankedValue(value); |
234 |
| - std::vector<Attribute> attrs; |
235 |
| - for (SymbolicDimOp sym : symbols) { |
236 |
| - Attribute rootSymbol = |
237 |
| - StrAttribute::get(m_->ir_context(), sym.GetSymName()); |
238 |
| - attrs.push_back(rootSymbol); |
239 |
| - } |
240 |
| - op->set_attribute(SymbolicDimOp::GetSymbolicDimAttrName(), |
241 |
| - ArrayAttribute::get(m_->ir_context(), attrs)); |
242 |
| - } |
243 |
| - dense_tensor_to_sym_dims_[value] = std::move(symbols); |
244 |
| - return true; |
245 |
| - } |
246 |
| - for (auto& result : op->results()) { |
247 |
| - if (!BuildShapeOnValue(result)) return false; |
248 |
| - } |
249 |
| - return true; |
250 |
| -} |
251 |
| - |
252 |
| -bool ShapeComputationIRAnalysis::BuildShapeOnValue(Value value) { |
253 |
| - Type type = value.type(); |
254 |
| - if (type.IsIntOrIndex()) { |
255 |
| - SymbolicDimOp sym = mgr_.NewSymbolicDim(); |
256 |
| - value_to_sym_dim_[value] = sym; |
257 |
| - } else if (IsCandidateShapeTensorType(type)) { |
258 |
| - auto shaped_type = type.dyn_cast<ShapedTypeInterface>(); |
259 |
| - std::vector<SymbolicDimOp> symbols; |
260 |
| - for (size_t i = 0, d = shaped_type.GetDyShape()[0]; i < d; ++i) |
261 |
| - symbols.push_back(mgr_.NewSymbolicDim()); |
262 |
| - shape_tensor_to_sym_dims_[value] = std::move(symbols); |
263 |
| - } |
264 |
| - return true; |
265 |
| -} |
266 |
| - |
267 |
| -bool ShapeComputationIRAnalysis::ApplyOpConstraint(Operation* op) { |
268 |
| - IR_ENFORCE(ApplyIndexOpConstraint(op), |
269 |
| - "Fail to apply constraint for index op"); |
270 |
| - IR_ENFORCE(ApplyTieShapeOpConstraint(op), |
271 |
| - "Fail to apply constraint for tie_shape op"); |
272 |
| - |
273 |
| - // TODO(zhangbopd): add more constraints |
274 |
| - return true; |
275 |
| -} |
276 |
| - |
277 |
| -bool ShapeComputationIRAnalysis::ApplyIndexOpConstraint(Operation* op) { |
278 |
| - if (op->num_results() == 0) return true; |
279 |
| - |
280 |
| - Type type = op->result(0).type(); |
281 |
| - if (!type.IsIntOrIndex()) return true; |
282 |
| - |
283 |
| - if (auto dim_op = op->dyn_cast<shape::TensorDimOp>()) { |
284 |
| - int64_t dim_index = dim_op.index() |
285 |
| - .dyn_cast<OpResult>() |
286 |
| - .owner() |
287 |
| - ->attribute<Int64Attribute>("value") |
288 |
| - .data(); |
289 |
| - value_to_sym_dim_[dim_op.out()].UpdateKnownNonNegative(true); |
290 |
| - if (!mgr_.MapSymbolicDimEqual( |
291 |
| - value_to_sym_dim_[dim_op.out()], |
292 |
| - dense_tensor_to_sym_dims_[dim_op.source()][dim_index])) { |
293 |
| - return false; |
294 |
| - } |
295 |
| - |
296 |
| - } else if (auto const_op = op->dyn_cast<ConstantOp>()) { |
297 |
| - int64_t val = const_op.value().dyn_cast<Int64Attribute>().data(); |
298 |
| - if (!mgr_.MapSymbolicDimEqual(value_to_sym_dim_[op->result(0)], |
299 |
| - mgr_.NewConstantSymbolicDim(val))) { |
300 |
| - return false; |
301 |
| - } |
302 |
| - } |
303 |
| - // TODO(zhangbopd): add support for reifyInferShape. (e.g. mul/add) |
304 |
| - return true; |
305 |
| -} |
306 |
| - |
307 |
| -bool ShapeComputationIRAnalysis::ApplyTieShapeOpConstraint(Operation* op) { |
308 |
| - if (auto tie_shape = op->dyn_cast<shape::TieShapeOp>()) { |
309 |
| - auto& value = dense_tensor_to_sym_dims_[op->operand_source(0)]; |
310 |
| - for (size_t idx = 0; idx < tie_shape.dims().size(); ++idx) { |
311 |
| - if (!mgr_.MapSymbolicDimEqual(value_to_sym_dim_[tie_shape.dims()[idx]], |
312 |
| - value[idx])) |
313 |
| - return false; |
314 |
| - mgr_.GetRootSymbolicDim(value[idx]).UpdateKnownNonNegative(true); |
315 |
| - } |
316 |
| - } |
317 |
| - return true; |
318 |
| -} |
319 |
| - |
320 |
| -bool OptimizeShapeComputation(pir::ModuleOp m, PassPipelineRunner runner) { |
321 |
| - // TODO(zhangbopd): Do some Canonicalizer. |
322 |
| - pir::SymbolicDimMgr mgr(m); |
323 |
| - |
324 |
| - ShapeComputationIRAnalysis analysis(m, mgr); |
325 |
| - if (!analysis.Run()) { |
326 |
| - return false; |
327 |
| - } |
328 |
| - |
329 |
| - return true; |
330 |
| -} |
331 |
| - |
332 | 35 | void PrintProgram(pir::ModuleOp m, std::string mgs) {
|
333 | 36 | std::ostringstream print_stream;
|
334 | 37 | print_stream << "\n\n";
|
|
0 commit comments