[Operator]Delete XShape for squeeze output #67355


Merged 7 commits on Aug 14, 2024. Changes from all commits.
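
The changes below remove the legacy XShape output of the squeeze op across the CINN lowering pass, the program translator, the decomposition and VJP rules, the SPMD rules, and the oneDNN kernels. As a hedged orientation sketch (the builder variable and the value names x_value/axes_value are assumptions, not code from this PR), downstream passes now bind a single result:

// Hypothetical snippet, assuming the paddle::dialect builder API used elsewhere in this diff.
auto squeeze = builder.Build<paddle::dialect::SqueezeOp>(x_value, axes_value);
pir::Value out = squeeze.result(0);  // before this PR the op also exposed result(1), the xshape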
@@ -915,8 +915,9 @@ class SqueezeOpPattern
in_shape[i]));
}
}

ReplaceWithCinnReshapeOp(op, rewriter, output_shape);
auto cinn_reshape = rewriter.Build<cinn::dialect::ReshapeOp>(
op->operand_source(0), output_shape);
rewriter.ReplaceAllUsesWith(op.result(0), cinn_reshape.result(0));
rewriter.EraseOp(op);

return true;
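
For readers outside CINN: the output_shape fed into the reshape above is simply the input shape with the requested size-1 dimensions removed. A minimal self-contained sketch of that squeeze semantics (plain STL, assuming nothing about the pattern's actual helpers):

#include <cstdint>
#include <set>
#include <vector>

// Drop size-1 dims at the given axes; with no axes, drop every size-1 dim.
std::vector<int64_t> SqueezedShape(const std::vector<int64_t>& in_shape,
                                   const std::vector<int64_t>& axes) {
  std::set<int64_t> drop;
  for (int64_t a : axes) {
    drop.insert(a < 0 ? a + static_cast<int64_t>(in_shape.size()) : a);
  }
  std::vector<int64_t> out;
  for (int64_t i = 0; i < static_cast<int64_t>(in_shape.size()); ++i) {
    const bool squeeze_this =
        axes.empty() ? (in_shape[i] == 1) : (drop.count(i) > 0 && in_shape[i] == 1);
    if (!squeeze_this) out.push_back(in_shape[i]);
  }
  return out;  // e.g. {1, 3, 1, 5} with axes {0, 2} -> {3, 5}
}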
@@ -956,7 +957,6 @@ class UnsqueezeOpPattern
output_shape.push_back(1);
}
}

ReplaceWithCinnReshapeOp(op, rewriter, output_shape);
rewriter.EraseOp(op);

4 changes: 0 additions & 4 deletions paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
@@ -141,10 +141,8 @@
"imag",
"diagonal",
"flatten",
"flatten_infer",
"reshape",
"slice",
"squeeze_infer",
"squeeze",
"strided_slice",
"strided_slice_raw",
@@ -164,9 +162,7 @@
"real_",
"imag_",
"diagonal_",
"flatten_infer_",
"slice_",
"squeeze_infer_",
"strided_slice_",
"strided_slice_raw_",
"tensor_unfold_",
159 changes: 112 additions & 47 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -3591,6 +3591,92 @@ struct QuantizeLinearOpTranscriber : public OpTranscriber {
}
};

// NOTE(Dev): helper functions for WithXShapeGradOpTranscriber and WithXShapeAndAxisGradOpTranscriber
static std::pair<pir::Value, pir::Value> ParseXAndOutGradValue(
const OpDesc& op_desc,
pir::IrContext* ctx,
pir::Builder* builder,
TranslationContext* param_map,
pir::Block* block) {
auto& input_xshape_name = op_desc.Input("XShape")[0];
auto& input_outgrad_name = op_desc.Input("Out@GRAD")[0];
pir::Value xshape_value;
VLOG(10) << "create data op for " << input_xshape_name;
auto var_desc = op_desc.Block()->FindVarRecursive(input_xshape_name);
auto dtype = ::phi::TransToPhiDataType(var_desc->GetDataType());
auto shape_vec = var_desc->GetShape();
// NOTE(dev): GradOp depends on X instead of XShape, so we need to
// erase the first element of xshape.
shape_vec.erase(shape_vec.begin());
xshape_value = builder
->Build<paddle::dialect::DataOp>(
input_xshape_name, shape_vec, dtype, phi::Place())
.result(0);

VLOG(10) << "create data op for " << input_xshape_name << " done";

if (param_map->Has(input_xshape_name)) {
auto value =
param_map->at(input_xshape_name).value.dyn_cast<pir::OpResult>();
auto* defining_op = value.owner();
value.ReplaceAllUsesWith(xshape_value);
param_map->PopValue(input_xshape_name);
defining_op->Erase();
}

param_map->PushValue(input_xshape_name, xshape_value);
PADDLE_ENFORCE_EQ(param_map->Has(input_outgrad_name),
true,
common::errors::InvalidArgument(
"%s op does not have input Out@GRAD", op_desc.Type()));
auto input_outgrad_value_info = param_map->at(input_outgrad_name);
if (input_outgrad_value_info.generated_by_vector) {
InsertSliceOperationForTarget(
ctx, param_map, block, input_outgrad_value_info, input_outgrad_name);
input_outgrad_value_info = param_map->at(input_outgrad_name);
}
pir::Value input_outgrad_value = input_outgrad_value_info.value;

PADDLE_ENFORCE_EQ(
input_outgrad_value.type().isa<paddle::dialect::DenseTensorType>(),
true,
::common::errors::InvalidArgument(
"input type must be DenseTensorType, but received: %s.",
input_outgrad_value.type()));

return std::make_pair(xshape_value, input_outgrad_value);
}
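
The element erased above is the sentinel that the legacy squeeze2/reshape2 ops prepend when recording XShape, so stripping it recovers x's real dims. A tiny standalone illustration of that derivation (plain std::vector, not the VarDesc API):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Dims recorded in the legacy XShape variable: {0, d0, d1, ...}
  std::vector<int64_t> xshape_dims = {0, 4, 1, 8};
  // Drop the leading sentinel so the rebuilt data op carries x's real shape.
  std::vector<int64_t> x_dims(xshape_dims.begin() + 1, xshape_dims.end());
  assert((x_dims == std::vector<int64_t>{4, 1, 8}));
  return 0;
}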

static pir::Value ParseAxis(const OpDesc& op_desc,
TranslationContext* param_map,
pir::IrContext* ctx,
pir::Block* block) {
// process axes
if (op_desc.HasInput("AxesTensor") && !op_desc.Input("AxesTensor").empty()) {
// get axis from input
auto axis_var_list = op_desc.Input("AxesTensor");
PADDLE_ENFORCE_EQ(
axis_var_list.size(),
1UL,
common::errors::InvalidArgument(
"axis tensor input of %s MUST be a tensor", op_desc.Type()));
auto axis_defining_info = (*param_map)[axis_var_list[0]];
return axis_defining_info.value;
} else if (op_desc.HasInput("AxesTensorList") &&
!op_desc.Input("AxesTensorList").empty()) {
auto* combine_op = InsertCombineOperationForTarget(
ctx, param_map, block, op_desc.Input("AxesTensorList"));
return combine_op->result(0);
} else {
auto& attribute_translator = AttributeTranslator::instance();
pir::Attribute new_attr = attribute_translator(
"paddle::dialect::IntArrayAttribute", op_desc.GetAttr("axes"));
auto full_array_op =
InsertFullArrayOperationForAttributeInput(ctx, block, new_attr);
return full_array_op->result(0);
}
}

template <typename OpT>
struct WithXShapeGradOpTranscriber : public OpTranscriber {
pir::Operation* operator()(pir::IrContext* ctx,
@@ -3599,53 +3685,9 @@ struct WithXShapeGradOpTranscriber : public OpTranscriber {
pir::Block* block) override {
VLOG(4) << "Translate " << op_desc.Type() << ".....";
pir::Builder builder(ctx, block);
auto& input_xshape_name = op_desc.Input("XShape")[0];
auto& input_outgrad_name = op_desc.Input("Out@GRAD")[0];
auto [xshape_value, input_outgrad_value] =
ParseXAndOutGradValue(op_desc, ctx, &builder, param_map, block);
auto& out_name = op_desc.Output("X@GRAD")[0];
pir::Value xshape_value;
VLOG(10) << "create data op for " << input_xshape_name;
auto var_desc = op_desc.Block()->FindVarRecursive(input_xshape_name);
auto dtype = ::phi::TransToPhiDataType(var_desc->GetDataType());
auto shape_vec = var_desc->GetShape();
shape_vec.erase(shape_vec.begin());
xshape_value = builder
.Build<paddle::dialect::DataOp>(
input_xshape_name, shape_vec, dtype, phi::Place())
.result(0);

VLOG(10) << "create data op for " << input_xshape_name << " done";

if (param_map->Has(input_xshape_name)) {
auto value =
param_map->at(input_xshape_name).value.dyn_cast<pir::OpResult>();
auto* defining_op = value.owner();
value.ReplaceAllUsesWith(xshape_value);
param_map->PopValue(input_xshape_name);
defining_op->Erase();
}

param_map->PushValue(input_xshape_name, xshape_value);
auto* defining_op = xshape_value.dyn_cast<pir::OpResult>().owner();
auto attr_map = defining_op->attributes();

PADDLE_ENFORCE_EQ(param_map->Has(input_outgrad_name),
true,
common::errors::InvalidArgument(
"Reshape2_Grad op does not have input Out@GRAD"));
auto input_outgrad_value_info = param_map->at(input_outgrad_name);
if (input_outgrad_value_info.generated_by_vector) {
InsertSliceOperationForTarget(
ctx, param_map, block, input_outgrad_value_info, input_outgrad_name);
input_outgrad_value_info = param_map->at(input_outgrad_name);
}
pir::Value input_outgrad_value = input_outgrad_value_info.value;

PADDLE_ENFORCE_EQ(
input_outgrad_value.type().isa<paddle::dialect::DenseTensorType>(),
true,
::common::errors::InvalidArgument(
"input type must be DenseTensorType, but received: %s.",
input_outgrad_value.type()));
// NOTE(Aurelius84): Even though we use xshape to construct the grad op,
// the GradKernel still uses dx->dims by default.
OpT grad_op = builder.Build<OpT>(xshape_value, input_outgrad_value);
@@ -3655,6 +3697,28 @@ struct WithXShapeGradOpTranscriber : public OpTranscriber {
}
};

// NOTE(dev): Used for squeeze_grad and unsqueeze_grad.
template <typename OpT>
struct WithXShapeAndAxisGradOpTranscriber : public OpTranscriber {
pir::Operation* operator()(pir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
pir::Block* block) override {
VLOG(4) << "Translate " << op_desc.Type() << ".....";
pir::Builder builder(ctx, block);
auto [x_value, input_outgrad_value] =
ParseXAndOutGradValue(op_desc, ctx, &builder, param_map, block);
auto& out_name = op_desc.Output("X@GRAD")[0];
// NOTE(Aurelius84): Even though we use xshape to construct the grad op,
// the GradKernel still uses dx->dims by default.
pir::Value axis = ParseAxis(op_desc, param_map, ctx, block);
OpT grad_op = builder.Build<OpT>(x_value, input_outgrad_value, axis);
param_map->PushValue(out_name, grad_op.result(0));

return grad_op.operation();
}
};
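
The template is written so that unsqueeze2_grad could be routed the same way; a hypothetical registration (not part of this diff, and dialect::UnsqueezeGradOp is assumed to exist with the same (x, out_grad, axis) signature) would mirror the squeeze2_grad handler added below:

special_handlers["unsqueeze2_grad"] =
    WithXShapeAndAxisGradOpTranscriber<dialect::UnsqueezeGradOp>();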

OpTranslator::OpTranslator() {
pir::IrContext* ctx = pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
@@ -3752,7 +3816,8 @@ OpTranslator::OpTranslator() {
WithXShapeGradOpTranscriber<dialect::ReshapeGradOp>();
special_handlers["flatten_contiguous_range_grad"] =
WithXShapeGradOpTranscriber<dialect::FlattenGradOp>();
special_handlers["squeeze2_grad"] =
WithXShapeAndAxisGradOpTranscriber<dialect::SqueezeGradOp>();
}

} // namespace translator
} // namespace paddle
@@ -2271,8 +2271,6 @@ bool SqueezeOpInferSymbolicShape(

pir::Value res = op->result(0);
infer_context->SetShapeOrDataForValue(res, shape_data);
infer_context->SetShapeOrDataForValue(
op->result(1), CreateShapeOrDataForXShape(x_shape_or_data));

return true;
}
@@ -104,11 +104,10 @@ class FusedRotaryPositionEmbeddingPattern : public paddle::drr::DrrPatternBase {
const auto &concat_op_k = pat.Op(paddle::dialect::ConcatOp::name());
const auto &combine_k = pat.Op(pir::CombineOp::name());

squeeze({&pat.Tensor("cos"), &full_13()},
{&pat.Tensor("squeeze_out_cos"), &pat.Tensor("xshape")});
squeeze({&pat.Tensor("cos"), &full_13()}, {&pat.Tensor("squeeze_out_cos")});

squeeze_1({&pat.Tensor("sin"), &full_12()},
{&pat.Tensor("squeeze_out_sin"), &pat.Tensor("xshape")});
{&pat.Tensor("squeeze_out_sin")});

unsqueeze({&pat.Tensor("position_ids"), &full_11()},
{&pat.Tensor("unsqueeze_s_out_cos"), &pat.Tensor("xshape")});
@@ -38,8 +38,7 @@ class SqueezeTransposePattern : public paddle::drr::DrrPatternBase {
const auto &full_1 = pat.Op(paddle::dialect::FullIntArrayOp::name(),
{{"value", pat.Attr("full_1_value")}});

squeeze({&pat.Tensor("x"), &full_1()},
{&pat.Tensor("squeeze_out"), &pat.Tensor("xshape")});
squeeze({&pat.Tensor("x"), &full_1()}, {&pat.Tensor("squeeze_out")});

const auto &transpose = pat.Op(paddle::dialect::TransposeOp::name(),
{{"perm", pat.Attr("perm")}});
8 changes: 3 additions & 5 deletions paddle/fluid/primitive/composite/composite.h
@@ -571,13 +571,11 @@ Tensor relu6_decomp(const Tensor& x) {
}

template <typename T>
std::tuple<Tensor, Tensor> squeeze_decomp(const Tensor& x,
const IntArray& axis) {
Tensor squeeze_decomp(const Tensor& x, const IntArray& axis) {
auto axis_ = process_dims(x, axis.GetData());
auto out_shape = get_squeeze_dims(x, axis_);
Tensor out = reshape<T>(x, out_shape);
Tensor xshape;
return std::make_tuple(out, xshape);
return out;
}

template <typename T>
@@ -1460,7 +1458,7 @@ Tensor embedding_decomp(const Tensor& x,
if (x.dims().size() <= 1) {
res = gather<T>(weight_tmp, x);
if (x.dims().size() == 0) {
res = std::get<0>(squeeze_decomp<T>(res, {0}));
res = squeeze_decomp<T>(res, {0});
}
} else {
std::vector<int64_t> tar_shape{-1};
2 changes: 1 addition & 1 deletion paddle/fluid/primitive/rule/vjp/details.h
@@ -1119,7 +1119,7 @@ void softmax_grad(const Tensor& out,
}

template <typename T>
void squeeze_grad(const Tensor& xshape,
void squeeze_grad(const Tensor& x,
const Tensor& out_grad,
const IntArray& axis,
Tensor* x_grad) {
18 changes: 4 additions & 14 deletions paddle/phi/infermeta/spmd_rules/squeeze.cc
@@ -30,14 +30,6 @@ namespace distributed {

using phi::distributed::auto_parallel::str_join;

TensorDistAttr CreateSqueezeXshape(const TensorDistAttr& x) {
TensorDistAttr out(x);
auto dims_mapping = x.dims_mapping();
dims_mapping.insert(dims_mapping.begin(), -1);
out.set_dims_mapping(dims_mapping);
return out;
}

void MakeSqueezeDimTransWithoutAxis(
const std::vector<int64_t>& x_shape,
std::vector<int64_t>* out_shape,
@@ -168,8 +160,7 @@ SpmdInfo SqueezeInferSpmd(const DistMetaTensor& x,
<< "]\n Out dims_mapping: [" << str_join(dims_mapping_vec[1])
<< "]\n\n";

return {{x_dist_attr_dst},
{out_dist_attr, CreateSqueezeXshape(x_dist_attr_dst)}};
return {{x_dist_attr_dst}, {out_dist_attr}};
}

SpmdInfo SqueezeInferSpmdReverse(const DistMetaTensor& x,
@@ -246,13 +237,12 @@ SpmdInfo SqueezeInferSpmdReverse(const DistMetaTensor& x,
return {{x_dist_attr}, {out_dist_attr_dst}};
}

SpmdInfo SqueezeGradInferSpmd(const DistMetaTensor& xshape,
SpmdInfo SqueezeGradInferSpmd(const DistMetaTensor& x,
const DistMetaTensor& out_grad,
const IntArray& axis) {
auto shape = phi::vectorize(xshape.dims());
shape = std::vector<int64_t>(shape.begin() + 1, shape.end());
auto shape = phi::vectorize(x.dims());
const auto& spmd = ReshapeInferSpmd(out_grad, shape);
return {{xshape.dist_attr(), spmd.first[0]}, {spmd.second[0]}};
return {{x.dist_attr(), spmd.first[0]}, {spmd.second[0]}};
}

} // namespace distributed
30 changes: 15 additions & 15 deletions paddle/phi/kernels/onednn/squeeze_kernel.cc
@@ -54,10 +54,10 @@ void ExecuteSqueeze(const Context& dev_ctx,
}

template <typename T, typename Context>
void SqueezeInferKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out) {
void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out) {
auto x_dims = x.dims();
auto x_dims_tz = x_dims.size();
std::vector<int32_t> tmp(axes.GetData().begin(), axes.GetData().end());
@@ -87,13 +87,13 @@ void SqueezeInferKernel(const Context& dev_ctx,
}

template <typename T, typename Context>
void SqueezeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out,
DenseTensor* xshape) {
void SqueezeWithXShapeKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& axes,
DenseTensor* out,
DenseTensor* xshape) {
if (xshape == nullptr) {
SqueezeInferKernel<T, Context>(dev_ctx, x, axes, out);
SqueezeKernel<T, Context>(dev_ctx, x, axes, out);
} else {
auto x_dims = slice_ddim(xshape->dims(), 1, xshape->dims().size());
auto out_dims = out->dims();
@@ -102,12 +102,12 @@ void SqueezeKernel(const Context& dev_ctx,
}
} // namespace phi

PD_REGISTER_KERNEL(squeeze_infer,
PD_REGISTER_KERNEL(
squeeze, OneDNN, ONEDNN, phi::SqueezeKernel, float, phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(squeeze_with_xshape,
OneDNN,
ONEDNN,
phi::SqueezeInferKernel,
phi::SqueezeWithXShapeKernel,
float,
phi::dtype::bfloat16) {}

PD_REGISTER_KERNEL(
squeeze, OneDNN, ONEDNN, phi::SqueezeKernel, float, phi::dtype::bfloat16) {}