Modify Scale Compute to Support Mix Precision (PaddlePaddle#58811)
zhhsplendid authored Nov 13, 2023
1 parent e4ec410 commit 8f5e4b1
Showing 1 changed file with 25 additions and 17 deletions.
paddle/cinn/hlir/op/elementwise.cc (42 changes: 25 additions & 17 deletions)
@@ -157,23 +157,31 @@ std::shared_ptr<OpStrategy> StrategyForScale(
CHECK(pack_args[1].is_string());
std::string tensor_name = pack_args[1].operator std::string();

-  if (bias_after_scale) {
-    out = Compute(
-        A->shape,
-        [=](const std::vector<Expr> &indice) {
-          return ir::Cast::Make(A->type(), Expr(scale)) * A(indice) +
-                 ir::Cast::Make(A->type(), Expr(bias));
-        },
-        tensor_name);
-  } else {
-    out = Compute(
-        A->shape,
-        [=](const std::vector<Expr> &indice) {
-          return ir::Cast::Make(A->type(), Expr(scale)) *
-                 (A(indice) + ir::Cast::Make(A->type(), Expr(bias)));
-        },
-        tensor_name);
-  }
+  // Paddle upscales float16 or bfloat16 computation to float32;
+  // we make CINN consistent with this Paddle behavior.
+  bool should_upscale_fp32 =
+      A->type() == common::F16() || A->type() == common::BF16();
+
+  out = Compute(
+      A->shape,
+      [=](const std::vector<Expr> &indice) {
+        Expr cast_scale = should_upscale_fp32
+                              ? Expr(scale)
+                              : ir::Cast::Make(A->type(), Expr(scale));
+        Expr cast_bias = should_upscale_fp32
+                             ? Expr(bias)
+                             : ir::Cast::Make(A->type(), Expr(bias));
+        Expr cast_A_indice =
+            should_upscale_fp32 ? ir::Cast::Make(common::F32(), A(indice))
+                                : A(indice);
+        Expr add_result = bias_after_scale
+                              ? cast_scale * cast_A_indice + cast_bias
+                              : cast_scale * (cast_A_indice + cast_bias);
+        return should_upscale_fp32 ? ir::Cast::Make(A->type(), add_result)
+                                   : add_result;
+      },
+      tensor_name);
+
auto stages = CreateStages({out});
*ret = CINNValuePack{{CINNValue(Expr(out.get())), CINNValue(stages)}};
});
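
For readers who want the numerics in isolation, here is a minimal standalone C++ sketch of what the new kernel computes; it is not CINN IR and not part of the commit. The `ScaleRef` name and the template parameters are illustrative assumptions: a narrow element type T is widened to a compute type W (standing in for the fp16/bf16 to fp32 upscale), the multiply-add runs in W, and the result is cast back to T.

#include <iostream>

// Sketch only: emulate the scale op with an explicit upscale, mirroring the
// fp16/bf16 -> fp32 path above. T is the element type, W the compute type.
template <typename T, typename W = float>
T ScaleRef(T x, float scale, float bias, bool bias_after_scale) {
  W wide = static_cast<W>(x);  // upscale the input element
  W result = bias_after_scale
                 ? static_cast<W>(scale) * wide + static_cast<W>(bias)
                 : static_cast<W>(scale) * (wide + static_cast<W>(bias));
  return static_cast<T>(result);  // cast back to the input type
}

int main() {
  // bias_after_scale = true:  out = scale * x + bias
  // bias_after_scale = false: out = scale * (x + bias)
  std::cout << ScaleRef<float>(2.0f, 3.0f, 1.0f, true) << "\n";   // prints 7
  std::cout << ScaleRef<float>(2.0f, 3.0f, 1.0f, false) << "\n";  // prints 9
  return 0;
}

With a true 16-bit type for T, the widened path avoids the precision loss of doing the multiply-add directly in fp16/bf16, which is the behavior the commit aligns with Paddle.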