
Commit 8de8736

[CINN] Add ShapeOrDataDimExpr representation for TensorArray type (#65956)
* [CINN] Add a ShapeOrDataDimExpr representation for the TensorArray type
* Change name
* Fix some print info
Parent: 953aeea · Commit: 8de8736

10 files changed: 124 additions and 32 deletions

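The diffs below all follow one pattern: `symbol::ShapeOrDataDimExprs` gains a fourth variant alternative, `RankedTensorArrayShapeOrDataDimExprs`, so every `std::visit` over the variant must add a matching lambda or it no longer compiles. As illustration only, here is a minimal self-contained C++17 sketch of that pattern; the types and the `Overloaded` helper are stand-ins (mirroring `common::Overloaded`), not the real Paddle classes.

#include <cstdint>
#include <iostream>
#include <string>
#include <variant>
#include <vector>

// Stand-ins for the real symbol:: types (illustration only).
struct TensorExprs { std::vector<std::int64_t> dims; };
using TensorListExprs = std::vector<TensorExprs>;
struct RankedTensorArrayExprs { std::vector<std::int64_t> shape_hint; };
using NullExprs = std::monostate;

using ShapeOrData =
    std::variant<NullExprs, TensorExprs, TensorListExprs, RankedTensorArrayExprs>;

// Local equivalent of the common::Overloaded helper used in the passes.
template <class... Ts>
struct Overloaded : Ts... { using Ts::operator()...; };
template <class... Ts>
Overloaded(Ts...) -> Overloaded<Ts...>;

std::string Describe(const ShapeOrData& value) {
  return std::visit(
      Overloaded{
          [](const TensorExprs&) { return std::string("tensor"); },
          [](const TensorListExprs&) { return std::string("tensor_list"); },
          // The branch every visit call site in this commit has to add;
          // without it the overload set is not exhaustive and std::visit
          // fails to compile.
          [](const RankedTensorArrayExprs&) { return std::string("tensor_array"); },
          [](const NullExprs&) { return std::string("null"); }},
      value);
}

int main() {
  std::cout << Describe(RankedTensorArrayExprs{{-1, 3, 224}}) << "\n";  // tensor_array
  return 0;
}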

paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_util.cc

Lines changed: 5 additions & 0 deletions
@@ -188,6 +188,9 @@ struct ShapeSignatureGenerator {
           }
         },
         [&](const symbol::TensorListShapeOrDataDimExprs& impl) { return; },
+        [&](const symbol::RankedTensorArrayShapeOrDataDimExprs& impl) {
+          return;
+        },
         [&](const symbol::NullShapeOrDataDimExpr& impl) { return; });
   };
 
@@ -249,6 +252,8 @@ struct ShapeSignatureGenerator {
         [&](const symbol::TensorListShapeOrDataDimExprs& impl) -> ResType {
           return std::make_pair(std::nullopt, std::nullopt);
         },
+        [&](const symbol::RankedTensorArrayShapeOrDataDimExprs& impl)
+            -> ResType { return std::make_pair(std::nullopt, std::nullopt); },
         [&](const symbol::NullShapeOrDataDimExpr& impl) -> ResType {
           return std::make_pair(std::nullopt, std::nullopt);
         });

paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc

Lines changed: 3 additions & 0 deletions
@@ -70,6 +70,9 @@ std::vector<pir::Value> FindSourceDenseTensorOfDimTensor(
       [](const symbol::TensorListShapeOrDataDimExprs& dim_expr) {
         return true;
       },
+      [](const symbol::RankedTensorArrayShapeOrDataDimExprs& dim_expr) {
+        return false;
+      },
       [](const symbol::NullShapeOrDataDimExpr& null_shape_or_data) {
         return false;
       }};

paddle/cinn/hlir/dialect/operator/transforms/group_merge/simplify_dim_expr_pass.cc

Lines changed: 17 additions & 13 deletions
@@ -51,25 +51,24 @@ void VisitEachValue(const pir::Operation& op, const DoEachT& DoEach) {
   }
 }
 
+std::vector<symbol::DimExpr> SimplifyDimExprVector(
+    const std::vector<symbol::DimExpr>& original_dim_exprs) {
+  std::vector<symbol::DimExpr> simplified_dim_exprs{};
+  for (const symbol::DimExpr& dim_expr : original_dim_exprs) {
+    simplified_dim_exprs.push_back(symbol::SimplifyDimExpr(dim_expr));
+  }
+  return simplified_dim_exprs;
+}
+
 symbol::TensorShapeOrDataDimExprs SimplifyTensorShapeOrData(
     const symbol::TensorShapeOrDataDimExprs& shape_or_data) {
-  const auto& SimplifyDimExpr =
-      [](const std::vector<symbol::DimExpr>& original_dim_expr)
-      -> std::vector<symbol::DimExpr> {
-    std::vector<symbol::DimExpr> simplified_dim_expr{};
-    for (const symbol::DimExpr& dim_expr : original_dim_expr) {
-      simplified_dim_expr.push_back(symbol::SimplifyDimExpr(dim_expr));
-    }
-    return simplified_dim_expr;
-  };
-
   std::vector<symbol::DimExpr> simplified_shape =
-      SimplifyDimExpr(shape_or_data.shape());
+      SimplifyDimExprVector(shape_or_data.shape());
   if (!shape_or_data.data().has_value()) {
     return symbol::ShapeOrData<symbol::DimExpr>(simplified_shape);
   }
   std::vector<symbol::DimExpr> simplified_data =
-      SimplifyDimExpr(shape_or_data.data().value());
+      SimplifyDimExprVector(shape_or_data.data().value());
   return symbol::ShapeOrData<symbol::DimExpr>(simplified_shape,
                                               simplified_data);
 }
@@ -90,7 +89,12 @@ symbol::ShapeOrDataDimExprs SimplifyShapeOrData(
         }
         return symbol::ShapeOrDataDimExprs(simplified_tensor_list);
       },
-      [&](const symbol::NullShapeOrDataDimExpr& null_shape_or_data) {
+      [](const symbol::RankedTensorArrayShapeOrDataDimExprs& tensor_array) {
+        return symbol::ShapeOrDataDimExprs(
+            symbol::RankedTensorArrayShapeOrDataDimExprs(
+                SimplifyDimExprVector(tensor_array.GetShapeHint())));
+      },
+      [](const symbol::NullShapeOrDataDimExpr& null_shape_or_data) {
        return symbol::ShapeOrDataDimExprs(null_shape_or_data);
      }};
   return std::visit(lambdas, shape_or_data.variant());

paddle/cinn/hlir/dialect/operator/transforms/local_infer_symbolic_util.cc

Lines changed: 8 additions & 0 deletions
@@ -75,13 +75,21 @@ void InitLocalShapeAnalysis(const pir::Operation& op,
         }
         return symbol::ShapeOrDataDimExprs(ret);
       };
+  auto NewSymbolReplacedTensorArray =
+      [&](const symbol::RankedTensorArrayShapeOrDataDimExprs&
+              tensor_array_shape) {
+        return symbol::ShapeOrDataDimExprs(
+            symbol::RankedTensorArrayShapeOrDataDimExprs(
+                NewSymbolReplacedDimExprs(tensor_array_shape.GetShapeHint())));
+      };
   auto NewSymbolReplacedNull =
       [&](const symbol::NullShapeOrDataDimExpr& null_shape_or_data) {
         return symbol::ShapeOrDataDimExprs(null_shape_or_data);
       };
   auto GetNewSymbolReplaced = [&](const auto& value_dim_exprs) {
     auto patterns = common::Overloaded{NewSymbolReplacedTensor,
                                        NewSymbolReplacedTensorList,
+                                       NewSymbolReplacedTensorArray,
                                        NewSymbolReplacedNull};
     return std::visit(patterns, value_dim_exprs.variant());
   };

paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/collect_sym_expr.cc

Lines changed: 8 additions & 0 deletions
@@ -73,6 +73,14 @@ void VisitEachDimExpr(const symbol::ShapeOrDataDimExprs& shape_or_data,
          VisitEachDimExprFromTensorShapeOrData(tensor_shape_or_data, DoEach);
         }
       },
+      [&](const symbol::RankedTensorArrayShapeOrDataDimExprs& tensor_array) {
+        PADDLE_THROW(phi::errors::Fatal(
+            "Dead code, TensorArray should not be handled in backend."));
+        for (const symbol::DimExpr& dim_expr : tensor_array.GetShapeHint()) {
+          DoEach(dim_expr);
+        }
+        return;
+      },
      [&](const symbol::NullShapeOrDataDimExpr& null_shape_or_data) {
        return;
      }};

paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.cc

Lines changed: 4 additions & 0 deletions
@@ -671,6 +671,10 @@ struct PirToPyCodeConverterHelper {
        [](const symbol::TensorListShapeOrDataDimExprs& impl) {
          return ConvertTensorListShapeOrData(impl);
        },
+       [](const symbol::RankedTensorArrayShapeOrDataDimExprs& impl) {
+         // TODO(Hongqing-work): support tensor_array to py
+         return std::string("self.s_tensor_array()");
+       },
        [](const symbol::NullShapeOrDataDimExpr& impl) {
          return std::string("self.s_null()");
        });

paddle/fluid/pir/dialect/operator/utils/shape_analysis_utils.cc

Lines changed: 5 additions & 0 deletions
@@ -45,6 +45,11 @@ symbol::ShapeOrDataDimExprs ClearDataInfo(
        }
        return symbol::ShapeOrDataDimExprs{new_shape_exprs};
      },
+      [](const symbol::RankedTensorArrayShapeOrDataDimExprs& shape_exprs) {
+        return symbol::ShapeOrDataDimExprs{
+            symbol::RankedTensorArrayShapeOrDataDimExprs{
+                shape_exprs.GetShapeHint()}};
+      },
      [](const symbol::NullShapeOrDataDimExpr& null_shape_or_data) {
        return symbol::ShapeOrDataDimExprs{null_shape_or_data};
      }};

paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h

Lines changed: 45 additions & 4 deletions
@@ -114,12 +114,41 @@ class ShapeOrData {
   std::optional<std::vector<T>> data_;
 };
 
+using NullShapeOrDataDimExpr = std::monostate;
 using TensorShapeOrDataDimExprs = ShapeOrData<DimExpr>;
 using TensorListShapeOrDataDimExprs = std::vector<TensorShapeOrDataDimExprs>;
-using NullShapeOrDataDimExpr = std::monostate;
-using ShapeOrDataDimExprsBase = std::variant<NullShapeOrDataDimExpr,
-                                             TensorShapeOrDataDimExprs,
-                                             TensorListShapeOrDataDimExprs>;
+
+/* TensorArray can append tensors dynamically. In a static graph, we only
+ * store the shape of one element as a hint, because we assume that all elements
+ * in the TensorArray have the same rank, and with equal constraints on specific
+ * dimensions. */
+class RankedTensorArrayShapeOrDataDimExprs {
+ public:
+  RankedTensorArrayShapeOrDataDimExprs() = default;
+  explicit RankedTensorArrayShapeOrDataDimExprs(
+      const std::vector<DimExpr>& shape)
+      : shape_hint_{shape} {}
+  const std::vector<DimExpr>& GetShapeHint() const { return shape_hint_; }
+  bool operator==(const RankedTensorArrayShapeOrDataDimExprs& other) const {
+    if (shape_hint_.size() != other.shape_hint_.size()) return false;
+    for (size_t i = 0; i < shape_hint_.size(); ++i) {
+      DimExpr dim0 = symbol::SimplifyDimExpr(shape_hint_[i]);
+      DimExpr dim1 = symbol::SimplifyDimExpr(other.shape_hint_[i]);
+      if (dim0 != dim1) return false;
+    }
+
+    return true;
+  }
+
+ private:
+  std::vector<DimExpr> shape_hint_;
+};
+
+using ShapeOrDataDimExprsBase =
+    std::variant<NullShapeOrDataDimExpr,
+                 TensorShapeOrDataDimExprs,
+                 TensorListShapeOrDataDimExprs,
+                 RankedTensorArrayShapeOrDataDimExprs>;
 
 class ShapeOrDataDimExprs : public ShapeOrDataDimExprsBase {
  public:
@@ -131,6 +160,10 @@ class ShapeOrDataDimExprs : public ShapeOrDataDimExprsBase {
       const TensorListShapeOrDataDimExprs& tensor_list_dim_exprs)
       : ShapeOrDataDimExprsBase(tensor_list_dim_exprs) {}
 
+  ShapeOrDataDimExprs(const RankedTensorArrayShapeOrDataDimExprs&
+                          tensor_array_dim_exprs)  // NOLINT
+      : ShapeOrDataDimExprsBase(tensor_array_dim_exprs) {}
+
   ShapeOrDataDimExprs(const NullShapeOrDataDimExpr& null_dim_expr)  // NOLINT
       : ShapeOrDataDimExprsBase(null_dim_expr) {}
 
@@ -227,6 +260,14 @@ struct hash<symbol::TensorListShapeOrDataDimExprs> {
   }
 };
 
+template <>
+struct hash<symbol::RankedTensorArrayShapeOrDataDimExprs> {
+  std::size_t operator()(
+      const symbol::RankedTensorArrayShapeOrDataDimExprs& obj) const {
+    return std::hash<std::vector<symbol::DimExpr>>()(obj.GetShapeHint());
+  }
+};
+
 template <>
 struct hash<symbol::ShapeOrDataDimExprs> {
   std::size_t operator()(const symbol::ShapeOrDataDimExprs& obj) const {
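For illustration only, a short usage sketch of the class declared above, assumed to be compiled inside the Paddle tree. The `DimExpr` constructors from a symbol name and an integer are taken from the existing `symbol::DimExpr` API; the helper function is hypothetical and not part of the commit.

#include <cstdint>
#include <vector>

#include "paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h"

// Hypothetical helper (not from the commit): build a TensorArray shape-or-data
// value whose elements are all assumed to have the shape [S0, 128].
symbol::ShapeOrDataDimExprs MakeTensorArrayExprs() {
  std::vector<symbol::DimExpr> hint{symbol::DimExpr("S0"),
                                    symbol::DimExpr(std::int64_t{128})};
  // One hint describes every element: fixed rank, possibly symbolic dims.
  symbol::RankedTensorArrayShapeOrDataDimExprs array_exprs(hint);

  // operator== compares the simplified hints element-wise, so an equivalent
  // hint constructed separately compares equal.
  bool same =
      array_exprs == symbol::RankedTensorArrayShapeOrDataDimExprs(hint);
  (void)same;

  // Convertible into the general variant, like the other alternatives
  // (see the new ShapeOrDataDimExprs constructor above).
  return symbol::ShapeOrDataDimExprs(array_exprs);
}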

paddle/pir/src/dialect/shape/utils/shape_analysis.cc

Lines changed: 5 additions & 0 deletions
@@ -273,6 +273,11 @@ InferSymbolicShapeContext::SimplifyBroadcastForShapeOrData(
        }
        return symbol::ShapeOrDataDimExprs(simplified_tensor_list);
      },
+      [&](const symbol::RankedTensorArrayShapeOrDataDimExprs& tensor_array) {
+        symbol::RankedTensorArrayShapeOrDataDimExprs simplified_tensor_array(
+            DimExprsVisitor(tensor_array.GetShapeHint()));
+        return symbol::ShapeOrDataDimExprs(simplified_tensor_array);
+      },
      [&](const symbol::NullShapeOrDataDimExpr& null_shape_or_data) {
        return symbol::ShapeOrDataDimExprs(null_shape_or_data);
      });

paddle/pir/src/dialect/shape/utils/shape_or_data_expr.cc

Lines changed: 24 additions & 15 deletions
@@ -15,28 +15,27 @@
 #include "paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h"
 
 namespace symbol {
+std::vector<DimExpr> SubstituteDimExprVector(
+    const std::vector<DimExpr>& original_dim_expr,
+    const std::unordered_map<DimExpr, DimExpr>& substitution_pattern) {
+  std::vector<DimExpr> substituted_dim_expr{};
+  for (const DimExpr& dim_expr : original_dim_expr) {
+    const auto& tmp_dim_expr =
+        SubstituteDimExpr(dim_expr, substitution_pattern);
+    substituted_dim_expr.push_back(SimplifyDimExpr(tmp_dim_expr));
+  }
+  return substituted_dim_expr;
+}
+
 TensorShapeOrDataDimExprs SubstituteTensorShapeOrData(
     const TensorShapeOrDataDimExprs& shape_or_data,
     const std::unordered_map<DimExpr, DimExpr>& substitution_pattern) {
-  auto SubstituteOneDimExpr =
-      [](const std::vector<DimExpr>& original_dim_expr,
-         const std::unordered_map<DimExpr, DimExpr>& substitution_pattern)
-      -> std::vector<DimExpr> {
-    std::vector<DimExpr> substituted_dim_expr{};
-    for (const DimExpr& dim_expr : original_dim_expr) {
-      const auto& tmp_dim_expr =
-          SubstituteDimExpr(dim_expr, substitution_pattern);
-      substituted_dim_expr.push_back(SimplifyDimExpr(tmp_dim_expr));
-    }
-    return substituted_dim_expr;
-  };
-
   std::vector<DimExpr> substituted_shape =
-      SubstituteOneDimExpr(shape_or_data.shape(), substitution_pattern);
+      SubstituteDimExprVector(shape_or_data.shape(), substitution_pattern);
   if (!shape_or_data.data().has_value()) {
     return ShapeOrData<DimExpr>(substituted_shape);
   } else {
-    std::vector<DimExpr> substituted_data = SubstituteOneDimExpr(
+    std::vector<DimExpr> substituted_data = SubstituteDimExprVector(
        shape_or_data.data().value(), substitution_pattern);
     return ShapeOrData<DimExpr>(substituted_shape, substituted_data);
   }
@@ -59,6 +58,12 @@ ShapeOrDataDimExprs SubstituteShapeOrData(
        }
        return ShapeOrDataDimExprs(substituted_tensor_list);
      },
+      [&](const RankedTensorArrayShapeOrDataDimExprs& tensor_array) {
+        RankedTensorArrayShapeOrDataDimExprs substituted_tensor_array(
+            SubstituteDimExprVector(tensor_array.GetShapeHint(),
+                                    substitution_pattern));
+        return ShapeOrDataDimExprs(substituted_tensor_array);
+      },
      [&](const NullShapeOrDataDimExpr& null_shape_or_data) {
        return ShapeOrDataDimExprs(null_shape_or_data);
      }};
@@ -89,6 +94,10 @@ std::ostream& operator<<(std::ostream& stream,
          }
        }
      },
+      [&](const RankedTensorArrayShapeOrDataDimExprs& tensor_array_shape_data) {
+        stream << "TensorArray with shape hint: "
+               << tensor_array_shape_data.GetShapeHint();
+      },
      [&](const NullShapeOrDataDimExpr& null_shape_data) {
        stream << "shape[NULL], data[NULL]";
      }};
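A small hypothetical fragment (not part of the commit) showing the new `operator<<` branch in use. It assumes the headers above; the exact rendering of the hint vector comes from the existing `DimExpr` printing utilities.

#include <cstdint>
#include <iostream>

#include "paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h"

int main() {
  symbol::ShapeOrDataDimExprs value(
      symbol::RankedTensorArrayShapeOrDataDimExprs(
          {symbol::DimExpr("S0"), symbol::DimExpr(std::int64_t{4})}));
  // Dispatches to the TensorArray lambda added above and prints
  // "TensorArray with shape hint: " followed by the hint's dim exprs.
  std::cout << value << std::endl;
  return 0;
}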
