Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into clean
zhangbopd committed Jan 10, 2024
2 parents f79fdb2 + da5399a commit ee57e81
Showing 126 changed files with 3,400 additions and 971 deletions.
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
@@ -29,7 +29,7 @@ if(NOT DEFINED XPU_BASE_DATE)
set(XPU_BASE_DATE "20231218")
endif()
if(NOT DEFINED XPU_XHPC_BASE_DATE)
set(XPU_XHPC_BASE_DATE "20231229")
set(XPU_XHPC_BASE_DATE "20240105")
endif()
set(XPU_XCCL_BASE_VERSION "1.1.8.1")
if(NOT DEFINED XPU_XFT_BASE_VERSION)
61 changes: 17 additions & 44 deletions paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc
@@ -314,51 +314,14 @@ class SubstituteDimExprHelper final {
DimExpr4SymbolNameT DimExpr4SymbolName_;
};

std::optional<DimExpr> SubstituteDimExpr(
DimExpr SubstituteDimExpr(
const DimExpr& dim_expr,
const std::function<std::optional<DimExpr>(const std::string& symbol_name)>&
DimExpr4SymbolName) {
return SubstituteDimExprHelper(DimExpr4SymbolName).Substitute(dim_expr);
}

std::function<std::optional<DimExpr>(const std::string& symbol_name)>
MakeGetterDimExpr4SymbolName(
const std::vector<std::tuple<std::string /*symbol_name*/,
int /*in_tensor_idx*/,
int /*in_tensor_dim_idx*/>>& symbol_bindings,
const std::function<std::optional<DimExpr>(
int in_tensor_idx, int in_tensor_dim_idx)>& DimExpr4InputDim) {
std::unordered_map<std::string, std::vector<std::pair<int, int>>>
symbol_name2in_tensor_dim_pos;
for (const auto& tuple : symbol_bindings) {
const auto& [symbol_name, in_tensor_idx, in_tensor_dim_idx] = tuple;
symbol_name2in_tensor_dim_pos[symbol_name].emplace_back(
std::pair{in_tensor_idx, in_tensor_dim_idx});
}
return [map = std::move(symbol_name2in_tensor_dim_pos), DimExpr4InputDim](
const std::string& symbol_name) -> std::optional<DimExpr> {
const auto& iter = map.find(symbol_name);
if (iter == map.end()) {
return std::nullopt;
}
const auto& positions = iter->second;
std::optional<DimExpr> ret = std::nullopt;
for (const auto& [in_tensor_idx, in_tensor_dim_idx] : positions) {
const auto& current = DimExpr4InputDim(in_tensor_idx, in_tensor_dim_idx);
if (!current.has_value()) {
return std::nullopt;
}
if (ret.has_value()) {
// Same names, same DimExprs.
if (ret.value() != current.value()) {
return std::nullopt;
}
} else {
ret = current;
}
}
return ret;
};
const auto& opt_substituted =
SubstituteDimExprHelper(DimExpr4SymbolName).Substitute(dim_expr);
if (opt_substituted.has_value()) return opt_substituted.value();
return dim_expr;
}
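
The hunk above changes SubstituteDimExpr from returning std::optional<DimExpr> to returning DimExpr: when no binding matches, the original expression is now returned instead of std::nullopt. A minimal standalone sketch of that fallback contract, using a string stand-in rather than the real symbol::DimExpr:

#include <functional>
#include <iostream>
#include <optional>
#include <string>

// Stand-in for symbol::DimExpr; names here are illustrative, not Paddle's.
using FakeDimExpr = std::string;

// Mirrors the new contract: substitute when a binding exists, otherwise
// return the input expression unchanged.
FakeDimExpr SubstituteOrKeep(
    const FakeDimExpr& dim_expr,
    const std::function<std::optional<FakeDimExpr>(const std::string&)>&
        DimExpr4SymbolName) {
  const std::optional<FakeDimExpr> substituted = DimExpr4SymbolName(dim_expr);
  if (substituted.has_value()) return substituted.value();
  return dim_expr;  // no binding found: keep the original expression
}

int main() {
  const auto lookup = [](const std::string& name) -> std::optional<FakeDimExpr> {
    if (name == "S0") return FakeDimExpr{"128"};
    return std::nullopt;
  };
  std::cout << SubstituteOrKeep("S0", lookup) << "\n";  // prints 128
  std::cout << SubstituteOrKeep("S1", lookup) << "\n";  // prints S1 (unchanged)
}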

namespace {
@@ -387,6 +350,12 @@ std::optional<DimExpr> GetDimExprBySymbolBindingImpl(
return shape_or_data_dim_expr.shape().at(dim_idx);
}

std::string GetSymbolNameBySymbolBinding(
const GenerateShapeOp::SymbolBinding& symbol_binding) {
return std::visit([](const auto& impl) { return impl.symbol_name; },
symbol_binding);
}

} // namespace

std::function<std::optional<DimExpr>(const std::string& symbol_name)>
@@ -396,6 +365,10 @@ MakeGetterDimExpr4SymbolName(
DimExpr4InputDim) {
std::unordered_map<std::string, std::vector<GenerateShapeOp::SymbolBinding>>
symbol_name2symbol_bindins{};
for (const auto& symbol_binding : symbol_bindings) {
symbol_name2symbol_bindins[GetSymbolNameBySymbolBinding(symbol_binding)]
.emplace_back(symbol_binding);
}
const auto& GetDimExpr =
[&](const GenerateShapeOp::SymbolBinding& symbol_binding) {
return std::visit(
@@ -596,14 +569,14 @@ void GenerateSymbolBindings(
std::vector<pir::Value> GetMinimalInputs(
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value,
const std::vector<pir::Value>& input_tensors) {
std::unordered_set<symbol::DimExpr> handdled_dim_exprs;
std::unordered_set<symbol::DimExpr> handled_dim_exprs;
std::unordered_set<pir::Value> first_occurred_input_tensors;
auto TryCollectFirstOcurredInput_tensor =
[&](pir::Value input_tensor,
const std::vector<symbol::DimExpr>& dim_exprs) {
for (const auto& dim_expr : dim_exprs) {
if (dim_expr.isa<int64_t>()) continue;
if (!handdled_dim_exprs.insert(dim_expr).second) {
if (handled_dim_exprs.insert(dim_expr).second) {
first_occurred_input_tensors.insert(input_tensor);
}
}
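
The GetMinimalInputs hunk above fixes the collection condition: a tensor is now recorded when insert(...).second is true, i.e. the first time one of its non-constant dims is seen, rather than on repeats. A standalone sketch of that first-occurrence idiom, with int stand-ins for dim exprs and tensor ids (hypothetical data, not Paddle types):

#include <iostream>
#include <unordered_set>
#include <utility>
#include <vector>

int main() {
  // Stand-ins: dims are ints, tensors are ids; not symbol::DimExpr / pir::Value.
  const std::vector<std::pair<int, std::vector<int>>> tensors_with_dims = {
      {/*tensor_id=*/0, {7, 8}}, {/*tensor_id=*/1, {8, 9}}, {/*tensor_id=*/2, {7}}};

  std::unordered_set<int> handled_dims;
  std::unordered_set<int> first_occurred_tensors;
  for (const auto& [tensor_id, dims] : tensors_with_dims) {
    for (int dim : dims) {
      // insert(...).second is true only on the first occurrence of dim,
      // so the tensor that introduces a new dim is the one that gets kept.
      if (handled_dims.insert(dim).second) {
        first_occurred_tensors.insert(tensor_id);
      }
    }
  }
  for (int id : first_occurred_tensors) std::cout << id << " ";  // 0 and 1
  std::cout << "\n";
}
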
12 changes: 2 additions & 10 deletions paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h
@@ -29,27 +29,19 @@ ::pir::Attribute ConvertDimExprToAttribute(pir::IrContext* ctx,
std::optional<symbol::DimExpr> ConvertAttributeToDimExpr(
::pir::Attribute attribute);

std::optional<symbol::DimExpr> SubstituteDimExpr(
symbol::DimExpr SubstituteDimExpr(
const symbol::DimExpr& dim_expr,
const std::function<std::optional<symbol::DimExpr>(
const std::string& symbol_name)>& DimExpr4SymbolName);

std::function<std::optional<symbol::DimExpr>(const std::string& symbol_name)>
MakeGetterDimExpr4SymbolName(
const std::vector<std::tuple<std::string /*symbol_name*/,
int /*in_tensor_idx*/,
int /*in_tensor_dim_idx*/>>& symbol_bindings,
const std::function<std::optional<symbol::DimExpr>(
int in_tensor_idx, int in_tensor_dim_idx)>& DimExpr4InputDim);

std::function<std::optional<symbol::DimExpr>(const std::string& symbol_name)>
MakeGetterDimExpr4SymbolName(
const GenerateShapeOp::SymbolBindings& symbol_bindings,
const std::function<const symbol::ShapeOrDataDimExprs&(int in_tensor_idx)>&
DimExpr4InputDim);

using ShapeOrDataDimExprs4ValueT =
std::function<const symbol::ShapeOrDataDimExprs&(pir::Value)>;
std::function<symbol::ShapeOrDataDimExprs(pir::Value)>;

// Returns true if success.
bool MakeGenerateShapeOpAttribute(
1 change: 1 addition & 0 deletions paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc
@@ -55,6 +55,7 @@ void OperatorDialect::initialize() {
RegisterOp<GroupOp>();
RegisterOp<ConcatOp>();
RegisterOp<SplitOp>();
RegisterOp<GenerateShapeOp>();
RegisterAttribute<GroupInfoAttribute>();
RegisterAttribute<CINNKernelInfoAttribute>();
}
9 changes: 9 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
@@ -38,4 +38,13 @@ if(NOT CINN_ONLY)
cinn_op_dialect
op_dialect_vjp)

cinn_cc_library(
split_generate_shape_into_shape_ops_pass
SRCS
split_generate_shape_into_shape_ops_pass.cc
DEPS
pir
cinn_op_dialect
op_dialect_vjp)

endif()
@@ -32,6 +32,8 @@ namespace cinn {
namespace dialect {
namespace ir {

namespace {

pir::Value GetOutputDimTensor(pir::PatternRewriter* rewriter,
pir::Value x,
pir::Value y) {
@@ -42,6 +44,10 @@ pir::Value GetOutputDimTensor(pir::PatternRewriter* rewriter,
}

bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
if (op->operand_source(0).defining_op()->isa<paddle::dialect::ExpandOp>() &&
op->operand_source(1).defining_op()->isa<paddle::dialect::ExpandOp>()) {
return false;
}
pir::Value x = op->operand_source(0);
pir::Value y = op->operand_source(1);
pir::Value output_dim_tensor = GetOutputDimTensor(rewriter, x, y);
@@ -58,6 +64,8 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
return true;
}

} // namespace

template <typename OPTYPE>
class FullyInsertBroadcastPattern : public pir::OpRewritePattern<OPTYPE> {
public:
@@ -27,6 +27,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/match_context.h"
#include "paddle/pir/core/builtin_dialect.h"
#include "paddle/pir/dialect/shape/utils/shape_utils.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/pattern_applicator.h"
#include "paddle/pir/pattern_rewrite/pattern_match.h"
@@ -38,6 +39,9 @@ namespace ir {

namespace {

using ShapeOrDataDimExprs4ValueT =
std::function<symbol::ShapeOrDataDimExprs(pir::Value)>;

std::vector<pir::Value> FindSourceDenseTensorOfDimTensor(
pir::Value shape,
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) {
@@ -126,9 +130,17 @@ std::optional<pir::Value> GetOutOfRewritedGenerateShapeOp(
.out();
}

bool ProcessOp(paddle::dialect::ExpandOp op,
pir::PatternRewriter* rewriter,
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) {
bool ProcessOp(paddle::dialect::ExpandOp op, pir::PatternRewriter* rewriter) {
if (op.shape().defining_op()->isa<cinn::dialect::GenerateShapeOp>()) {
return false;
}
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value =
[&op](pir::Value value) -> symbol::ShapeOrDataDimExprs {
pir::ShapeConstraintIRAnalysis& shape_analysis =
pir::ShapeAnalysisManager::Instance().Get(
op.x().defining_op()->GetParentProgram());
return shape_analysis.GetShapeOrDataForValue(value);
};
std::optional<pir::Value> opt_generated_shape =
GetOutOfRewritedGenerateShapeOp(
op.shape(), rewriter, ShapeOrDataDimExprs4Value);
@@ -143,32 +155,25 @@ template <typename OPTYPE>
class FuseShapeOpsIntoGenerateShapeOpPattern
: public pir::OpRewritePattern<OPTYPE> {
public:
FuseShapeOpsIntoGenerateShapeOpPattern(
pir::IrContext* context,
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value)
: pir::OpRewritePattern<OPTYPE>(context),
ShapeOrDataDimExprs4Value_(ShapeOrDataDimExprs4Value) {}
explicit FuseShapeOpsIntoGenerateShapeOpPattern(pir::IrContext* context)
: pir::OpRewritePattern<OPTYPE>(context) {}

bool MatchAndRewrite(OPTYPE op,
pir::PatternRewriter& rewriter) const override {
return ProcessOp(op, &rewriter, ShapeOrDataDimExprs4Value_);
return ProcessOp(op, &rewriter);
}

private:
ShapeOrDataDimExprs4ValueT ShapeOrDataDimExprs4Value_;
};

FuseShapeOpsIntoGenerateShapeOpPass::FuseShapeOpsIntoGenerateShapeOpPass(
const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value)
: pir::PatternRewritePass("fuse_shape_ops_into_generate_shape_op_pass", 1),
ShapeOrDataDimExprs4Value_(ShapeOrDataDimExprs4Value) {}
FuseShapeOpsIntoGenerateShapeOpPass::FuseShapeOpsIntoGenerateShapeOpPass()
: pir::PatternRewritePass("fuse_shape_ops_into_generate_shape_op_pass", 1) {
}

pir::RewritePatternSet FuseShapeOpsIntoGenerateShapeOpPass::InitializePatterns(
pir::IrContext* context) {
pir::RewritePatternSet ps(context);
// elementwise ops
ps.Add<FuseShapeOpsIntoGenerateShapeOpPattern<paddle::dialect::ExpandOp>>(
context, ShapeOrDataDimExprs4Value_);
context);

return ps;
}
@@ -24,17 +24,11 @@ namespace ir {

class FuseShapeOpsIntoGenerateShapeOpPass : public pir::PatternRewritePass {
public:
using ShapeOrDataDimExprs4ValueT =
std::function<const symbol::ShapeOrDataDimExprs &(pir::Value)>;
explicit FuseShapeOpsIntoGenerateShapeOpPass(
const ShapeOrDataDimExprs4ValueT &ShapeOrDataDimExprs4Value);
FuseShapeOpsIntoGenerateShapeOpPass();

pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override;

bool CanApplyOn(pir::Operation *op) const override;

private:
ShapeOrDataDimExprs4ValueT ShapeOrDataDimExprs4Value_;
};

} // namespace ir
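
The last two files remove the ShapeOrDataDimExprs4Value constructor parameter from FuseShapeOpsIntoGenerateShapeOpPass; the pattern now looks shape data up on demand through pir::ShapeAnalysisManager::Instance(), keyed by the op's parent program. A standalone sketch of that design shift, with hypothetical stand-in classes rather than the real pir API:

#include <iostream>
#include <string>
#include <unordered_map>

// Stand-in for the per-program shape analysis (not pir::ShapeConstraintIRAnalysis).
struct FakeShapeAnalysis {
  std::string GetShapeOrDataForValue(const std::string& value) const {
    return "shape_of(" + value + ")";
  }
};

// Stand-in singleton, analogous in spirit to pir::ShapeAnalysisManager::Instance().
struct FakeShapeAnalysisManager {
  static FakeShapeAnalysisManager& Instance() {
    static FakeShapeAnalysisManager instance;
    return instance;
  }
  FakeShapeAnalysis& Get(const std::string& program) { return analyses_[program]; }

 private:
  std::unordered_map<std::string, FakeShapeAnalysis> analyses_;
};

// The pattern no longer stores an injected callback; it queries the manager when it runs.
struct FakePattern {
  bool MatchAndRewrite(const std::string& program, const std::string& value) const {
    FakeShapeAnalysis& analysis = FakeShapeAnalysisManager::Instance().Get(program);
    std::cout << analysis.GetShapeOrDataForValue(value) << "\n";
    return true;
  }
};

int main() {
  FakePattern pattern;  // constructed without any shape callback
  pattern.MatchAndRewrite("main_program", "expand.shape");
}

The trade-off is an implicit global lookup instead of an injected callback, which is what allows the pattern and pass to be default-constructed in the hunks above.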
