Skip to content

[CINN]Apply broadcast device lowering for too many broadcast tree bugs #66207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 3 additions & 9 deletions paddle/cinn/common/broadcast_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@
#include <unordered_map>

#include "paddle/common/enforce.h"
#include "paddle/common/flags.h"
#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h"

COMMON_DECLARE_int64(pir_broadcast_tree_limit);

namespace cinn::common {

namespace {
Expand Down Expand Up @@ -95,9 +92,6 @@ template <typename DoEachT>
bool SearchBroadcastImpl(const symbol::Broadcast<symbol::DimExpr>& variadic,
const DoEachT& DoEach) {
const auto& operands = *(variadic.operands);
if (operands.size() > 3) {
PADDLE_THROW(phi::errors::Fatal("Too many broadcast leaves to compile!"));
}
for (const auto& operand : operands) {
CHECK(!operand.isa<int64_t>());
if (SearchBroadcast(operand, DoEach)) return true;
Expand Down Expand Up @@ -310,13 +304,13 @@ std::optional<symbol::Broadcastable<symbol::DimExpr>> GetFirstCstrBroadcastable(

BroadcastTree ConstructBroadcastTree(const BroadcastLeaf& leaves,
int* num_of_leaves) {
if (*num_of_leaves > FLAGS_pir_broadcast_tree_limit) {
return leaves;
}
std::optional<symbol::Broadcastable<symbol::DimExpr>>
broadcastable_condition = GetFirstCstrBroadcastable(leaves);
if (!broadcastable_condition.has_value()) {
(*num_of_leaves)++;
if (*num_of_leaves > FLAGS_pir_broadcast_tree_limit) {
PADDLE_THROW(phi::errors::Fatal("Too many broadcast leaves to compile!"));
}
return leaves;
}
return ConstructBroadcastBranch(
Expand Down
3 changes: 3 additions & 0 deletions paddle/cinn/common/broadcast_tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
#pragma once

#include "paddle/cinn/adt/tree.h"
#include "paddle/common/flags.h"
#include "paddle/pir/include/dialect/shape/utils/dim_expr.h"

COMMON_DECLARE_int64(pir_broadcast_tree_limit);

namespace cinn::common {

template <typename T>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.h"
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h"
#include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
Expand All @@ -26,6 +27,8 @@ using OpLoweringGroupPtr = std::shared_ptr<OpLoweringGroup>;
using cinn::dialect::ir::details::CompileGroupAsOpAttribute;
using cinn::dialect::ir::details::GetBlockOutsideInput;

PD_DECLARE_bool(cinn_bc_branch_optimize);

namespace {
std::vector<pir::Value> GetOpOuputValues(const pir::Operation* op) {
std::vector<pir::Value> outputs;
Expand Down Expand Up @@ -438,15 +441,19 @@ void SimplyConditionBlock(

namespace cinn::dialect::ir::details {

std::shared_ptr<BroadcastTree> ConstructBroadcastTree(
std::optional<std::shared_ptr<BroadcastTree>> ConstructBroadcastTree(
const cinn::common::BroadcastLeaf& leaves) {
VLOG(6) << "before constructed. broadcast-leaf: \n"
<< ToTxtString(cinn::common::BroadcastTree(leaves));
int num_of_leaves = 0;
auto broadcast_tree = std::make_shared<cinn::common::BroadcastTree>(
cinn::common::ConstructBroadcastTree(cinn::common::BroadcastLeaf(leaves),
&num_of_leaves));
VLOG(4) << "num of broadcast tree leaves:" << num_of_leaves;
if (num_of_leaves > FLAGS_pir_broadcast_tree_limit) {
LOG(WARNING) << "the number of leaf nodes in broadcast tree exceeds "
"limit.";
return std::nullopt;
}
VLOG(4) << "broadcast-tree: \n" << ToTxtString(*broadcast_tree);
return broadcast_tree;
}
Expand Down Expand Up @@ -478,13 +485,24 @@ GroupDimExprInfo GetGroupDimExprInfo(const OpLoweringGroupPtr& group) {
return group_dim_expr_info;
}

bool NeedBroadcastWithCF(const OpLoweringGroupPtr& group) {
GroupDimExprInfo group_dim_expr_info = GetGroupDimExprInfo(group);
const auto& leaves = group_dim_expr_info.all_value_dim_exprs;
return NeedBroadcastWithCF(leaves);
std::optional<std::shared_ptr<BroadcastTree>> GetBroadcastTreeForOptimize(
const OpLoweringGroupPtr& group) {
if (!FLAGS_cinn_bc_branch_optimize) return std::nullopt;

const common::BroadcastLeaf leaves = [&]() {
// NOTE(dev): Need UpdateShapeOrDataExprs firstly and the logic
// will be migated into BucketLower later.
UpdateGroupShapeOrDataExprs(const_cast<OpLoweringGroupPtr&>(group));
Comment on lines +488 to +495
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: broadcast tree完全接入之后此函数将仅用于pre_compiler判断,将不需要const_cast,这里只生效保底机制故将原来逻辑直接迁移至此保留const_cast

GroupDimExprInfo group_dim_expr_info = GetGroupDimExprInfo(group);
return group_dim_expr_info.all_value_dim_exprs;
}();

if (!ContainBroadcastShape(leaves)) return std::nullopt;

return ConstructBroadcastTree(leaves);
}

bool NeedBroadcastWithCF(const cinn::common::BroadcastLeaf& leaves) {
bool ContainBroadcastShape(const cinn::common::BroadcastLeaf& leaves) {
std::optional<symbol::Broadcastable<symbol::DimExpr>>
broadcastable_condition = cinn::common::GetFirstCstrBroadcastable(leaves);
return broadcastable_condition.has_value();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ struct GroupDimExprInfo {
std::unordered_map<pir::Value, size_t> value_to_dim_expr_idx;
};

std::shared_ptr<BroadcastTree> ConstructBroadcastTree(
std::optional<std::shared_ptr<BroadcastTree>> ConstructBroadcastTree(
const common::BroadcastLeaf& leaves);

bool NeedBroadcastWithCF(const OpLoweringGroupPtr& group);
bool NeedBroadcastWithCF(const common::BroadcastLeaf& leaves);
std::optional<std::shared_ptr<BroadcastTree>> GetBroadcastTreeForOptimize(
const OpLoweringGroupPtr& group);
bool ContainBroadcastShape(const common::BroadcastLeaf& leaves);
GroupDimExprInfo GetGroupDimExprInfo(const OpLoweringGroupPtr& group);

pir::Operation* CompileBroadcastTreeToConditionBlock(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,17 @@
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/pass/pass_registry.h"

PD_DECLARE_bool(cinn_bc_branch_optimize);

namespace cinn::dialect::ir::details {

pir::Operation* ProcessDyShapeGroup(const OpLoweringGroupPtr& group,
pir::PatternRewriter& rewriter) { // NOLINT
// NOTE(dev): Need UpdateShapeOrDataExprs firstly and the logic
// will be migated into BucketLower later.
UpdateGroupShapeOrDataExprs(const_cast<OpLoweringGroupPtr&>(group));
auto group_inputs = GetBlockOutsideInput(group->ops());
GroupDimExprInfo group_dim_expr_info = GetGroupDimExprInfo(group);
const auto& leaves = group_dim_expr_info.all_value_dim_exprs;
// has multiple branch
if (FLAGS_cinn_bc_branch_optimize && NeedBroadcastWithCF(leaves)) {
const auto& value_to_dim_expr_idx =
group_dim_expr_info.value_to_dim_expr_idx;
const auto& group_inputs = GetBlockOutsideInput(group->ops());
const auto& optional_broadcast_tree = GetBroadcastTreeForOptimize(group);
if (optional_broadcast_tree.has_value()) {
const std::shared_ptr<BroadcastTree> broadcast_tree =
ConstructBroadcastTree(leaves);
optional_broadcast_tree.value();
const auto& value_to_dim_expr_idx =
GetGroupDimExprInfo(group).value_to_dim_expr_idx;
std::vector<pir::Type> output_types;
auto group_output_values = group->GetGroupOutputValues();
for (size_t i = 0; i < group_output_values.size(); ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ void FusionOpAnalysis::PreCompileGroup() {

std::vector<OpLoweringGroupPtr> groups;
for (auto& group_info : *group_infos_) {
if (is_dy_shape_ && NeedBroadcastWithCF(group_info.second)) continue;
if (is_dy_shape_ &&
GetBroadcastTreeForOptimize(group_info.second).has_value())
continue;
groups.push_back(group_info.second);
}
// Build and trigger compilaion cache.
Expand Down