Commit d32eb2d

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into hack6_13

2 parents d51cddd + 70f2b54

28 files changed: +1680 / -128 lines

paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc

Lines changed: 18 additions & 0 deletions
```diff
@@ -107,6 +107,14 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
 
   if (x_dims != y_dims) {
     auto output_shape = GetOutputShape(x_dims, y_dims);
+    pir::ShapeConstraintIRAnalysis& shape_analysis =
+        pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
+    std::vector<symbol::DimExpr> out_dim;
+    out_dim.reserve(output_shape.size());
+    for (auto d : output_shape) {
+      out_dim.emplace_back(d);
+    }
+
     if (!IsSameDim(x_dims, output_shape)) {
       // add broadcast to input 0
       if (auto full_op = op->operand_source(0)
@@ -122,13 +130,18 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
                 .dyn_cast<paddle::dialect::PlaceAttribute>()
                 .data());
         op->operand(0).set_source(new_full->result(0));
+        shape_analysis.SetShapeOrDataForValue(
+            new_full.result(0), symbol::TensorShapeOrDataDimExprs(out_dim));
       } else {
         auto new_transpose_op = rewriter->Build<cinn::dialect::BroadcastOp>(
             op->operand_source(0),
             cinn::hlir::framework::pir::GetBroadcastAxis(x_dims, output_shape),
             output_shape);
 
         op->operand(0).set_source(new_transpose_op->result(0));
+        shape_analysis.SetShapeOrDataForValue(
+            new_transpose_op.result(0),
+            symbol::TensorShapeOrDataDimExprs(out_dim));
       }
     }
 
@@ -147,13 +160,18 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
                 .data());
 
         op->operand(1).set_source(new_full->result(0));
+        shape_analysis.SetShapeOrDataForValue(
+            new_full.result(0), symbol::TensorShapeOrDataDimExprs(out_dim));
       } else {
         auto new_transpose_op = rewriter->Build<cinn::dialect::BroadcastOp>(
             op->operand_source(1),
             cinn::hlir::framework::pir::GetBroadcastAxis(y_dims, output_shape),
             output_shape);
 
         op->operand(1).set_source(new_transpose_op->result(0));
+        shape_analysis.SetShapeOrDataForValue(
+            new_transpose_op.result(0),
+            symbol::TensorShapeOrDataDimExprs(out_dim));
       }
     }
 
```
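The `out_dim` values recorded via `SetShapeOrDataForValue` are just the broadcast output shape re-expressed as `symbol::DimExpr`, keeping the shape analysis consistent with the newly built ops. As a minimal, self-contained sketch of the right-aligned broadcasting rule that `GetOutputShape` presumably implements (plain C++, no Paddle types; `BroadcastShape` is a hypothetical stand-in):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Right-aligned elementwise broadcast: align trailing dims; a dim of 1
// stretches to match the other operand's dim.
std::vector<std::int64_t> BroadcastShape(const std::vector<std::int64_t>& x,
                                         const std::vector<std::int64_t>& y) {
  std::vector<std::int64_t> out(std::max(x.size(), y.size()));
  for (std::size_t i = 0; i < out.size(); ++i) {
    // Read dims from the right; missing leading dims behave like 1.
    std::int64_t xd = i < x.size() ? x[x.size() - 1 - i] : 1;
    std::int64_t yd = i < y.size() ? y[y.size() - 1 - i] : 1;
    assert((xd == yd || xd == 1 || yd == 1));  // shapes must be compatible
    out[out.size() - 1 - i] = std::max(xd, yd);
  }
  return out;
}

int main() {
  // [4, 1, 3] broadcast against [5, 3] yields [4, 5, 3].
  assert((BroadcastShape({4, 1, 3}, {5, 3}) ==
          std::vector<std::int64_t>{4, 5, 3}));
  return 0;
}
```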

paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc

Lines changed: 2 additions & 0 deletions
```diff
@@ -31,6 +31,7 @@
 #include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.h"
 #include "paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.h"
 #include "paddle/cinn/hlir/dialect/operator/transforms/fold_manipulation_ops_pass.h"
+#include "paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.h"
 #include "paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h"
 #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_dynamic_to_static_dim_pass.h"
 #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.h"
@@ -80,6 +81,7 @@ void ApplyPdToCinnPass(
     const std::function<std::shared_ptr<::pir::PassManager>()>&
         CreatePassManager) {
   std::shared_ptr<pir::PassManager> pass_manager = CreatePassManager();
+  pass_manager->AddPass(cinn::dialect::ir::CreateFuseParallelMatmulPass());
   pass_manager->AddPass(cinn::dialect::ir::CreatePdOpToCinnOpPass());
   pass_manager->AddPass(pir::CreateDeadCodeEliminationPass());
   pass_manager->Run(program);
```
paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.cc

Lines changed: 171 additions & 0 deletions (new file)

```cpp
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.h"

#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/common/ddim.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/include/core/builtin_dialect.h"
#include "paddle/pir/include/pass/pass.h"
#include "paddle/pir/include/pattern_rewrite/frozen_rewrite_pattern_set.h"
#include "paddle/pir/include/pattern_rewrite/pattern_applicator.h"
#include "paddle/pir/include/pattern_rewrite/pattern_match.h"
#include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h"

namespace cinn {
namespace dialect {
namespace ir {

class MergeParallelMatmulPattern
    : public pir::OpRewritePattern<paddle::dialect::MatmulOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::MatmulOp>::OpRewritePattern;

  bool MatchAndRewrite(paddle::dialect::MatmulOp matmul_op,
                       pir::PatternRewriter& rewriter) const override {
    auto ValidMatmulTranspose = [&](pir::Operation* op) -> bool {
      if (!op->dyn_cast<paddle::dialect::MatmulOp>()) {
        return false;
      }
      bool trans_x =
          op->attribute("transpose_x").dyn_cast<pir::BoolAttribute>().data();
      bool trans_y =
          op->attribute("transpose_y").dyn_cast<pir::BoolAttribute>().data();
      return !trans_x && !trans_y;
    };
    if (!ValidMatmulTranspose(matmul_op)) {
      return false;
    }

    auto VectorPrefixEqual = [](const std::vector<std::int64_t>& a,
                                const std::vector<std::int64_t>& b) {
      if (a.size() != b.size()) {
        return false;
      }
      for (int i = 0; i < a.size() - 1; ++i) {
        if (a[i] != b[i]) {
          return false;
        }
      }
      return true;
    };

    auto input_x = matmul_op.operand_source(0);
    const std::vector<pir::Operation*> merge_ops = [&]() {
      std::vector<pir::Operation*> ret;
      std::optional<std::vector<std::int64_t>> pre_dim;
      std::vector<std::int64_t> cur_dim;
      for (auto it = input_x.use_begin(); it != input_x.use_end(); ++it) {
        if (!ValidMatmulTranspose(it->owner())) {
          continue;
        }
        if (!pre_dim.has_value()) {
          pre_dim = ::common::vectorize(
              it->owner()
                  ->operand_source(1)
                  .type()
                  .dyn_cast<paddle::dialect::DenseTensorType>()
                  .dims());
        }
        cur_dim = ::common::vectorize(
            it->owner()
                ->operand_source(1)
                .type()
                .dyn_cast<paddle::dialect::DenseTensorType>()
                .dims());
        if (VectorPrefixEqual(pre_dim.value(), cur_dim)) {
          ret.push_back(it->owner());
        }
      }
      return ret;
    }();
    if (merge_ops.size() <= 1) {
      return false;
    }

    const std::vector<pir::Value> combine_ins = [&]() {
      std::vector<pir::Value> ret;
      for (pir::Operation* op : merge_ops) {
        ret.push_back(op->operand_source(1));
      }
      return ret;
    }();
    const std::vector<std::int64_t> combine_shapes = [&]() {
      std::vector<std::int64_t> ret{0};
      std::int64_t accumulate = 0;
      for (pir::Value input : combine_ins) {
        auto shape =
            input.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
        accumulate += shape[shape.size() - 1];
        ret.push_back(accumulate);
      }
      return ret;
    }();

    auto combine_out = rewriter.Build<pir::CombineOp>(combine_ins).result(0);
    auto concat_out =
        rewriter.Build<paddle::dialect::ConcatOp>(combine_out, -1).result(0);
    auto matmul_out =
        rewriter.Build<paddle::dialect::MatmulOp>(input_x, concat_out)
            .result(0);

    for (size_t i = 0; i < merge_ops.size(); ++i) {
      auto split_out =
          rewriter
              .Build<paddle::dialect::SliceOp>(
                  matmul_out,
                  std::vector<std::int64_t>{
                      matmul_out.type()
                          .dyn_cast<paddle::dialect::DenseTensorType>()
                          .dims()
                          .size() -
                      1},
                  std::vector<std::int64_t>{combine_shapes[i]},
                  std::vector<int64_t>{combine_shapes[i + 1]},
                  std::vector<std::int64_t>{},
                  std::vector<std::int64_t>{})
              .result(0);

      rewriter.ReplaceAllUsesWith(merge_ops[i]->result(0), split_out);
      rewriter.EraseOp(merge_ops[i]);
    }

    return true;
  }
};

class FuseParallelMatmulPass : public pir::PatternRewritePass {
 public:
  FuseParallelMatmulPass()
      : pir::PatternRewritePass("fuse_parallel_matmul_pass", 1) {}

  pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override {
    pir::RewritePatternSet ps(context);
    ps.Add<MergeParallelMatmulPattern>(context);
    return ps;
  }
};

std::unique_ptr<pir::Pass> CreateFuseParallelMatmulPass() {
  return std::make_unique<FuseParallelMatmulPass>();
}

}  // namespace ir
}  // namespace dialect
}  // namespace cinn
```
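The rewrite is justified by a simple identity: matmul distributes over column-wise concatenation of its right operands, so `X·B1, X·B2, ...` equals one `X·concat(B1, B2, ..., axis=-1)` followed by column slices at the prefix sums accumulated in `combine_shapes`. A self-contained numeric check of that identity (plain C++ stand-ins for the pir ops, not the pass itself):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Naive [m,k] x [k,n] matmul, stand-in for paddle::dialect::MatmulOp.
Matrix Matmul(const Matrix& a, const Matrix& b) {
  std::size_t m = a.size(), k = b.size(), n = b[0].size();
  Matrix out(m, std::vector<double>(n, 0.0));
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t p = 0; p < k; ++p)
      for (std::size_t j = 0; j < n; ++j) out[i][j] += a[i][p] * b[p][j];
  return out;
}

// Concat along the last axis, stand-in for CombineOp + ConcatOp(..., -1).
Matrix ConcatCols(const Matrix& a, const Matrix& b) {
  Matrix out = a;
  for (std::size_t i = 0; i < out.size(); ++i)
    out[i].insert(out[i].end(), b[i].begin(), b[i].end());
  return out;
}

// Column slice [lo, hi), stand-in for SliceOp on the last axis.
Matrix SliceCols(const Matrix& m, int lo, int hi) {
  Matrix out;
  for (const auto& row : m)
    out.emplace_back(row.begin() + lo, row.begin() + hi);
  return out;
}

int main() {
  Matrix x = {{1, 2}, {3, 4}};         // shared left operand
  Matrix b1 = {{1, 0}, {0, 1}};        // right operand #1, last dim 2
  Matrix b2 = {{2, 1, 0}, {0, 1, 2}};  // right operand #2, last dim 3

  // combine_shapes would be {0, 2, 5}: slice ranges [0,2) and [2,5).
  Matrix fused = Matmul(x, ConcatCols(b1, b2));
  assert(SliceCols(fused, 0, 2) == Matmul(x, b1));
  assert(SliceCols(fused, 2, 5) == Matmul(x, b2));
  return 0;
}
```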
paddle/cinn/hlir/dialect/operator/transforms/fuse_parallel_matmul_pass.h

Lines changed: 28 additions & 0 deletions (new file)

```cpp
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include "paddle/pir/include/pass/pass.h"

namespace cinn {
namespace dialect {
namespace ir {

IR_API std::unique_ptr<pir::Pass> CreateFuseParallelMatmulPass();

}  // namespace ir
}  // namespace dialect
}  // namespace cinn
```

paddle/cinn/hlir/pe/transform.cc

Lines changed: 35 additions & 3 deletions
```diff
@@ -27,6 +27,7 @@
 #include "paddle/cinn/lang/builtin.h"
 #include "paddle/cinn/lang/compute.h"
 #include "paddle/cinn/utils/string.h"
+#include "paddle/common/errors.h"
 
 namespace cinn {
 namespace hlir {
@@ -425,8 +426,9 @@ ir::Tensor Concat(const ir::Tensor& A,
 ir::Tensor Concat(const std::vector<ir::Tensor>& input_tensors,
                   int axis,
                   const std::string& name) {
+  // input size 1 is valid for Concat
   int input_size = input_tensors.size();
-  CHECK_GE(input_size, 2U) << "Concat should have at least 2 input tensors";
+  CHECK_GE(input_size, 1U) << "Concat should have at least 1 input tensors";
   std::vector<Expr> output_shape = input_tensors[0]->shape;
   int input_dim = output_shape.size();
   CHECK(axis >= -input_dim && axis < input_dim)
@@ -1057,7 +1059,7 @@ ir::Tensor Transpose(const ir::Tensor& input,
 
 ir::Tensor Slice(const ir::Tensor& A,
                  const std::vector<int>& starts,
-                 const std::vector<int>& axes,
+                 const std::vector<int>& const_axes,
                  const std::vector<int>& strides,
                  const std::vector<int>& decrease_axis,
                  const std::vector<Expr>& output_shape,
@@ -1066,6 +1068,21 @@ ir::Tensor Slice(const ir::Tensor& A,
   for (const auto& shape : A->shape) {
     input_shape.emplace_back(shape.as_int32());
   }
+  std::vector<int> axes;
+  std::transform(const_axes.begin(),
+                 const_axes.end(),
+                 std::back_inserter(axes),
+                 [rank = A->shape.size()](const int axis) -> int {
+                   if (axis < 0) {
+                     PADDLE_ENFORCE_GE(
+                         axis + rank,
+                         0,
+                         ::common::errors::InvalidArgument(
+                             "The axis of slice is out of range"));
+                     return axis + rank;
+                   }
+                   return axis;
+                 });
   std::vector<int> new_starts(starts);
   for (int i = 0; i < axes.size(); i++) {
     if (new_starts[i] < -input_shape[axes[i]]) {
@@ -1110,7 +1127,7 @@ ir::Tensor Slice(const ir::Tensor& A,
 
 ir::Tensor SliceSymbolic(const ir::Tensor& A,
                          const std::vector<int>& starts,
-                         const std::vector<int>& axes,
+                         const std::vector<int>& const_axes,
                          const std::vector<int>& strides,
                          const std::vector<int>& decrease_axis,
                          const std::vector<Expr>& output_shape,
@@ -1125,6 +1142,21 @@ ir::Tensor SliceSymbolic(const ir::Tensor& A,
                  starts.end(),
                  std::back_inserter(new_starts),
                  [](const int start) { return ir::Expr(start); });
+  std::vector<int> axes;
+  std::transform(const_axes.begin(),
+                 const_axes.end(),
+                 std::back_inserter(axes),
+                 [rank = A->shape.size()](const int axis) -> int {
+                   if (axis < 0) {
+                     PADDLE_ENFORCE_GE(
+                         axis + rank,
+                         0,
+                         ::common::errors::InvalidArgument(
+                             "The axis of slice is out of range"));
+                     return axis + rank;
+                   }
+                   return axis;
+                 });
 
   for (int i = 0; i < axes.size(); i++) {
     if (input_shape[axes[i]].is_constant()) {
```
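Both `Slice` and `SliceSymbolic` now canonicalize negative axes up front, so Python-style indices such as `-1` (last dimension) are valid by the time the bounds math runs. The same normalization in isolation (a hypothetical `NormalizeAxes` helper, with a plain exception standing in for `PADDLE_ENFORCE_GE`):

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <stdexcept>
#include <vector>

// Map negative axes onto [0, rank), mirroring the const_axes -> axes
// transform added to Slice and SliceSymbolic.
std::vector<int> NormalizeAxes(const std::vector<int>& const_axes, int rank) {
  std::vector<int> axes;
  std::transform(const_axes.begin(),
                 const_axes.end(),
                 std::back_inserter(axes),
                 [rank](const int axis) -> int {
                   if (axis < 0) {
                     if (axis + rank < 0) {
                       throw std::out_of_range(
                           "The axis of slice is out of range");
                     }
                     return axis + rank;
                   }
                   return axis;
                 });
  return axes;
}

int main() {
  // For a rank-3 tensor, -1 is the last dim (2) and -3 the first (0).
  assert((NormalizeAxes({-1, 0, -3}, 3) == std::vector<int>{2, 0, 0}));
  return 0;
}
```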

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 3 additions & 1 deletion
```diff
@@ -629,10 +629,12 @@ const std::vector<std::string> kPirMkldnnPasses{
     "matmul_transpose_reshape_fuse_pass",
     "matmul_elementwise_add_fuse_pass",
     "matmul_activation_fuse_pass",
+    "softplus_activation_fuse_pass",
     "conv_elementwise_add_onednn_fuse_pass",
     "conv_activation_onednn_fuse_pass",
     "conv_concat_activation_onednn_fuse_pass",
-    "elementwise_act_onednn_fuse_pass"};
+    "elementwise_act_onednn_fuse_pass",
+    "operator_unsqueeze_onednn_fuse_pass"};
 
 const std::vector<std::string> kPirCpuPasses{};
```

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc

Lines changed: 14 additions & 4 deletions
```diff
@@ -204,10 +204,20 @@
 
 bool ExpandAsOpInferSymbolicShape(
     pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) {
-  PADDLE_THROW(phi::errors::Unimplemented(
-      op->name() +
-      " 's InferSymbolicShape interface is NOT implemented "
-      "now because of the lack of necessary information."));
+  std::vector<int> target_shape =
+      paddle::dialect::details::GetVectorAttr<int>(op, "target_shape");
+  const std::vector<symbol::DimExpr> &output_dims = [&] {
+    std::vector<symbol::DimExpr> output_dims;
+    output_dims.reserve(target_shape.size());
+    for (int shape : target_shape) {
+      output_dims.push_back(shape);
+    }
+    return output_dims;
+  }();
+
+  shape_analysis->SetShapeOrDataForValue(
+      op->result(0), symbol::TensorShapeOrDataDimExprs(output_dims));
+
   return true;
 }
 
```
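`expand_as` broadcasts its input to exactly the shape named in `target_shape`, which is why the infer function can emit the attribute values directly as the result dims. A sketch of that contract (plain C++, hypothetical `ExpandAsShape` helper):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// expand_as result shape is exactly target_shape, provided each trailing
// input dim equals the target dim or is 1 (broadcastable).
std::vector<std::int64_t> ExpandAsShape(
    const std::vector<std::int64_t>& input,
    const std::vector<std::int64_t>& target) {
  assert(input.size() <= target.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    std::int64_t d = input[input.size() - 1 - i];
    assert((d == 1 || d == target[target.size() - 1 - i]));
  }
  return target;
}

int main() {
  assert((ExpandAsShape({1, 3}, {2, 4, 3}) ==
          std::vector<std::int64_t>{2, 4, 3}));
  return 0;
}
```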
