Skip to content

Commit 1b26966

Browse files
authored
[DimExpr] Convert Broadcast to BroadcastTree (#60440)
* backup BroadcastTree * add SubstituteDimExpr * add helper function ConstructBroadcastTree * Fix compile error * Code format * Polish DimExprUtilTest * Add cmake file * Change namesapce * Fix compile error * Fix unittest * reconstruct BroadcastTree * Polish DimExprUtilTest * Reconstruct BroadcastTree * Finish BroadcastBranch * Finish BroadcastBranch * Finish BroadcastBranch * Add Unittest * Remove unnecessary dim_expr_util * Add header file
1 parent e397b29 commit 1b26966

File tree

4 files changed

+450
-0
lines changed

4 files changed

+450
-0
lines changed

paddle/cinn/common/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ gather_srcs(
2424
integer_set.cc
2525
dim_expr_simplify.cc
2626
dim_expr_converter.cc
27+
broadcast_tree.cc
2728
dim_expr_util.cc)
2829

2930
cinn_cc_test(test_equation_graph_topo_walker SRCS
@@ -54,4 +55,5 @@ if(NOT CINN_ONLY)
5455
cinncore)
5556
cinn_cc_test(dim_expr_converter_test SRCS dim_expr_converter_test.cc DEPS
5657
cinncore)
58+
cinn_cc_test(broadcast_tree_test SRCS broadcast_tree_test.cc DEPS cinncore)
5759
endif()

paddle/cinn/common/broadcast_tree.cc

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/cinn/common/broadcast_tree.h"
16+
17+
#include <optional>
18+
#include <unordered_map>
19+
20+
#include "paddle/cinn/common/dim_expr_simplify.h"
21+
#include "paddle/cinn/common/dim_expr_util.h"
22+
23+
namespace cinn::common {
24+
25+
namespace {
26+
27+
template <typename DoEachT>
28+
bool SearchBroadcast(const symbol::DimExpr& dim_expr, const DoEachT& DoEach);
29+
30+
template <typename DoEachT>
31+
bool SearchBroadcastImpl(int64_t, const DoEachT& DoEach) {
32+
return false;
33+
}
34+
35+
template <typename DoEachT>
36+
bool SearchBroadcastImpl(const std::string&, const DoEachT& DoEach) {
37+
return false;
38+
}
39+
40+
template <typename T, typename DoEachT>
41+
bool SearchBroadcastImplForUnary(const T& unary, const DoEachT& DoEach) {
42+
const auto& operand = unary->data;
43+
return SearchBroadcast(operand, DoEach);
44+
}
45+
46+
template <typename DoEachT>
47+
bool SearchBroadcastImpl(const symbol::Negative<symbol::DimExpr>& unary,
48+
const DoEachT& DoEach) {
49+
return SearchBroadcastImplForUnary(unary, DoEach);
50+
}
51+
52+
template <typename DoEachT>
53+
bool SearchBroadcastImpl(const symbol::Reciprocal<symbol::DimExpr>& unary,
54+
const DoEachT& DoEach) {
55+
return SearchBroadcastImplForUnary(unary, DoEach);
56+
}
57+
58+
template <typename T, typename DoEachT>
59+
bool SearchBroadcastImplForVariadic(const T& variadic, const DoEachT& DoEach) {
60+
const auto& operands = *(variadic.operands);
61+
for (const auto& operand : operands) {
62+
if (SearchBroadcast(operand, DoEach)) return true;
63+
}
64+
return false;
65+
}
66+
67+
template <typename DoEachT>
68+
bool SearchBroadcastImpl(const symbol::Add<symbol::DimExpr>& variadic,
69+
const DoEachT& DoEach) {
70+
return SearchBroadcastImplForVariadic(variadic, DoEach);
71+
}
72+
73+
template <typename DoEachT>
74+
bool SearchBroadcastImpl(const symbol::Mul<symbol::DimExpr>& variadic,
75+
const DoEachT& DoEach) {
76+
return SearchBroadcastImplForVariadic(variadic, DoEach);
77+
}
78+
79+
template <typename DoEachT>
80+
bool SearchBroadcastImpl(const symbol::Max<symbol::DimExpr>& variadic,
81+
const DoEachT& DoEach) {
82+
return SearchBroadcastImplForVariadic(variadic, DoEach);
83+
}
84+
85+
template <typename DoEachT>
86+
bool SearchBroadcastImpl(const symbol::Min<symbol::DimExpr>& variadic,
87+
const DoEachT& DoEach) {
88+
return SearchBroadcastImplForVariadic(variadic, DoEach);
89+
}
90+
91+
template <typename DoEachT>
92+
bool SearchBroadcastImpl(const symbol::Broadcast<symbol::DimExpr>& variadic,
93+
const DoEachT& DoEach) {
94+
const auto& operands = *(variadic.operands);
95+
for (const auto& operand : operands) {
96+
CHECK(!operand.isa<int64_t>());
97+
if (SearchBroadcast(operand, DoEach)) return true;
98+
}
99+
return DoEach(variadic);
100+
}
101+
102+
template <typename DoEachT>
103+
bool SearchBroadcast(const symbol::DimExpr& dim_expr, const DoEachT& DoEach) {
104+
return std::visit(
105+
[&](const auto& impl) { return SearchBroadcastImpl(impl, DoEach); },
106+
dim_expr.variant());
107+
}
108+
109+
template <typename DoEachT>
110+
void ForEachBroadcastDimExpr(const BroadcastLeaf& leaves,
111+
const DoEachT& DoEach) {
112+
for (const auto& dim_exprs : *leaves) {
113+
for (const auto& dim_expr : dim_exprs) {
114+
if (SearchBroadcast(dim_expr, DoEach)) return;
115+
}
116+
}
117+
}
118+
119+
std::optional<symbol::Broadcastable<symbol::DimExpr>> GetFirstCstrBroadcastable(
120+
const BroadcastLeaf& leaves) {
121+
std::optional<symbol::Broadcastable<symbol::DimExpr>> ret;
122+
ForEachBroadcastDimExpr(leaves, [&](const auto& broadcast) -> bool {
123+
const auto& operands = broadcast.operands;
124+
std::optional<symbol::DimExpr> lhs_symbol;
125+
std::optional<symbol::DimExpr> rhs_symbol;
126+
size_t i = 0;
127+
for (; i < operands->size(); ++i) {
128+
if (operands->at(i).template isa<std::string>()) {
129+
lhs_symbol = operands->at(i);
130+
break;
131+
}
132+
}
133+
for (i++; i < operands->size(); ++i) {
134+
if (operands->at(i).template isa<std::string>()) {
135+
rhs_symbol = operands->at(i);
136+
break;
137+
}
138+
}
139+
if (lhs_symbol.has_value() && rhs_symbol.has_value()) {
140+
CHECK(lhs_symbol != rhs_symbol);
141+
ret = symbol::Broadcastable<symbol::DimExpr>{lhs_symbol.value(),
142+
rhs_symbol.value()};
143+
return true;
144+
}
145+
return false;
146+
});
147+
if (ret.has_value()) return ret.value();
148+
ForEachBroadcastDimExpr(leaves, [&](const auto& broadcast) -> bool {
149+
const auto& operands = broadcast.operands;
150+
std::optional<symbol::DimExpr> lhs_symbol;
151+
std::optional<symbol::DimExpr> rhs;
152+
for (const auto& operand : *operands) {
153+
if (operand.template isa<std::string>()) {
154+
lhs_symbol = operand;
155+
break;
156+
}
157+
}
158+
for (const auto& operand : *operands) {
159+
if (operand != lhs_symbol) {
160+
rhs = operand;
161+
break;
162+
}
163+
}
164+
if (lhs_symbol.has_value() && rhs.has_value()) {
165+
ret = symbol::Broadcastable<symbol::DimExpr>{lhs_symbol.value(),
166+
rhs.value()};
167+
return true;
168+
}
169+
return false;
170+
});
171+
if (ret.has_value()) return ret.value();
172+
ForEachBroadcastDimExpr(leaves, [&](const auto& broadcast) -> bool {
173+
const auto& operands = broadcast.operands;
174+
CHECK_GE(operands->size(), 2);
175+
CHECK(operands->at(0) != operands->at(1));
176+
ret = symbol::Broadcastable<symbol::DimExpr>{operands->at(0),
177+
operands->at(1)};
178+
return true;
179+
});
180+
return ret;
181+
}
182+
183+
using Pattern2Placement = std::unordered_map<symbol::DimExpr, symbol::DimExpr>;
184+
185+
Pattern2Placement ConstructCstrLhsEqRhsReplacement(
186+
const symbol::Broadcastable<symbol::DimExpr>& broadcastable_condition) {
187+
auto [lhs, rhs] = *broadcastable_condition;
188+
if (lhs.isa<std::string>()) return Pattern2Placement{{lhs, rhs}};
189+
if (rhs.isa<std::string>()) return Pattern2Placement{{rhs, lhs}};
190+
return Pattern2Placement{{lhs, rhs}};
191+
}
192+
193+
Pattern2Placement ConstructCstrLhsEqOneReplacement(
194+
const symbol::Broadcastable<symbol::DimExpr>& broadcastable_condition) {
195+
const auto& [lhs, rhs] = *broadcastable_condition;
196+
return Pattern2Placement{{lhs, symbol::DimExpr{1}}};
197+
}
198+
199+
Pattern2Placement ConstructCstrRhsEqOneReplacement(
200+
const symbol::Broadcastable<symbol::DimExpr>& broadcastable_condition) {
201+
const auto& [lhs, rhs] = *broadcastable_condition;
202+
return Pattern2Placement{{rhs, symbol::DimExpr{1}}};
203+
}
204+
205+
symbol::DimExpr GetCstrLhsEqRhsDimExpr(
206+
const symbol::Broadcastable<symbol::DimExpr>& broadcastable_condition,
207+
const symbol::DimExpr& dim_expr) {
208+
const auto& pattern2replacement =
209+
ConstructCstrLhsEqRhsReplacement(broadcastable_condition);
210+
return SimplifyDimExpr(SubstituteDimExpr(dim_expr, pattern2replacement));
211+
}
212+
213+
symbol::DimExpr GetCstrLhsEqOneDimExpr(
214+
const symbol::Broadcastable<symbol::DimExpr>& broadcastable_condition,
215+
const symbol::DimExpr& dim_expr) {
216+
const auto& pattern2replacement =
217+
ConstructCstrLhsEqOneReplacement(broadcastable_condition);
218+
return SimplifyDimExpr(SubstituteDimExpr(dim_expr, pattern2replacement));
219+
}
220+
221+
symbol::DimExpr GetCstrRhsEqOneDimExpr(
222+
const symbol::Broadcastable<symbol::DimExpr>& broadcastable_condition,
223+
const symbol::DimExpr& dim_expr) {
224+
const auto& pattern2replacement =
225+
ConstructCstrRhsEqOneReplacement(broadcastable_condition);
226+
return SimplifyDimExpr(SubstituteDimExpr(dim_expr, pattern2replacement));
227+
}
228+
229+
typedef symbol::DimExpr (*ConvertDimExprT)(
230+
const symbol::Broadcastable<symbol::DimExpr>& broadcastable_condition,
231+
const symbol::DimExpr& dim_expr);
232+
233+
template <ConvertDimExprT ConvertDimExpr>
234+
BroadcastLeaf ConvertBroadcastLeaf(
235+
const symbol::Broadcastable<symbol::DimExpr>& broadcastable_condition,
236+
const BroadcastLeaf& leaves) {
237+
BroadcastLeaf ret{};
238+
for (const auto& dim_exprs : *leaves) {
239+
std::vector<symbol::DimExpr> converted{};
240+
converted.reserve(dim_exprs.size());
241+
for (const auto& dim_expr : dim_exprs) {
242+
converted.push_back(ConvertDimExpr(broadcastable_condition, dim_expr));
243+
}
244+
ret->emplace_back(std::move(converted));
245+
}
246+
return ret;
247+
}
248+
249+
BroadcastLeaf GetCstrLhsEqRhsLeaves(
250+
const symbol::Broadcastable<symbol::DimExpr>& broadcastable_condition,
251+
const BroadcastLeaf& leaves) {
252+
return ConvertBroadcastLeaf<&GetCstrLhsEqRhsDimExpr>(broadcastable_condition,
253+
leaves);
254+
}
255+
256+
BroadcastLeaf GetCstrLhsEqOneLeaves(
257+
const symbol::Broadcastable<symbol::DimExpr>& broadcastable_condition,
258+
const BroadcastLeaf& leaves) {
259+
return ConvertBroadcastLeaf<&GetCstrLhsEqOneDimExpr>(broadcastable_condition,
260+
leaves);
261+
}
262+
263+
BroadcastLeaf GetCstrRhsEqOneLeaves(
264+
const symbol::Broadcastable<symbol::DimExpr>& broadcastable_condition,
265+
const BroadcastLeaf& leaves) {
266+
return ConvertBroadcastLeaf<&GetCstrRhsEqOneDimExpr>(broadcastable_condition,
267+
leaves);
268+
}
269+
270+
BroadcastBranch<BroadcastTree> ConstructBroadcastBranch(
271+
const symbol::Broadcastable<symbol::DimExpr>& broadcastable_condition,
272+
const BroadcastLeaf& leaves) {
273+
BroadcastLeaf cstr_lhs_eq_rhs_leaves =
274+
GetCstrLhsEqRhsLeaves(broadcastable_condition, leaves);
275+
BroadcastLeaf cstr_lhs_eq_one_leaves =
276+
GetCstrLhsEqOneLeaves(broadcastable_condition, leaves);
277+
BroadcastLeaf cstr_rhs_eq_one_leaves =
278+
GetCstrRhsEqOneLeaves(broadcastable_condition, leaves);
279+
// clang-format off
280+
return BroadcastBranch<BroadcastTree>{
281+
/*broadcastable_condition*/ broadcastable_condition,
282+
/*cstr_lhs_eq_rhs_branch*/ ConstructBroadcastTree(cstr_lhs_eq_rhs_leaves),
283+
/*cstr_lhs_eq_one_branch*/ ConstructBroadcastTree(cstr_lhs_eq_one_leaves),
284+
/*cstr_rhs_eq_one_branch*/ ConstructBroadcastTree(cstr_rhs_eq_one_leaves)
285+
};
286+
// clang-format on
287+
}
288+
289+
} // namespace
290+
291+
BroadcastTree ConstructBroadcastTree(const BroadcastLeaf& leaves) {
292+
std::optional<symbol::Broadcastable<symbol::DimExpr>>
293+
broadcastable_condition = GetFirstCstrBroadcastable(leaves);
294+
if (!broadcastable_condition.has_value()) return leaves;
295+
return ConstructBroadcastBranch(broadcastable_condition.value(), leaves);
296+
}
297+
298+
} // namespace cinn::common

paddle/cinn/common/broadcast_tree.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/cinn/adt/tree.h"
18+
#include "paddle/pir/dialect/shape/utils/dim_expr.h"
19+
20+
namespace cinn::common {
21+
22+
template <typename T>
23+
using BroadcastBranch = adt::Tuple<symbol::Broadcastable<symbol::DimExpr>,
24+
/*cstr_lhs_eq_rhs_branch*/ T,
25+
/*cstr_lhs_eq_one_branch*/ T,
26+
/*cstr_rhs_eq_one_branch*/ T>;
27+
28+
using BroadcastLeaf = adt::List<std::vector<symbol::DimExpr>>;
29+
30+
using BroadcastTree = adt::Tree<BroadcastBranch, BroadcastLeaf>;
31+
32+
BroadcastTree ConstructBroadcastTree(const BroadcastLeaf& leaves);
33+
34+
} // namespace cinn::common

0 commit comments

Comments
 (0)