Skip to content

Commit

Permalink
add runtime inferShape in sequence_pool op (PaddlePaddle#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
WorgenZhang authored and zmxdream committed Sep 5, 2022
1 parent 3e8d850 commit 5845fd0
Showing 1 changed file with 44 additions and 3 deletions.
47 changes: 44 additions & 3 deletions paddle/fluid/operators/sequence_ops/sequence_pool_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,43 @@ class SequencePoolOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SequencePool");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SequencePool");

if (!ctx->IsRuntime()) {
auto ins_dim = ctx->GetInputDim("X");
framework::DDim out_dims;
PADDLE_ENFORCE_EQ(ins_dim.size(), 2,
platform::errors::InvalidArgument(
"The dims size of first input should be equal to 2, "
"but received value is %d.",
ins_dim.size()));

if (ctx->IsRuntime()) {
auto inputs_tensor = ctx->GetInputVarPtrs("X");

int batch_size = -1;
int rank = ins_dim.size();
int cur_batch_size = 0;
framework::Variable* x_var =
BOOST_GET(framework::Variable*, inputs_tensor[0]);
const auto& x_tensor = x_var->Get<LoDTensor>();
const auto& x_lod = x_tensor.lod();
if (x_lod.size() > 0) {
cur_batch_size = x_lod[0].size() - 1;
} else {
cur_batch_size = x_tensor.dims()[0];
}
if (batch_size == -1) {
batch_size = cur_batch_size;
} else {
PADDLE_ENFORCE_EQ(batch_size, cur_batch_size,
platform::errors::PreconditionNotMet(
"The batch size of all input should be same, "
"please check, last batch_size is %d, current "
"batch_size is %d",
batch_size, cur_batch_size));
}
std::vector<int64_t> out_dim;
out_dim = {batch_size, ins_dim[rank - 1]};
out_dims = phi::make_ddim(out_dim);
} else {
// Check the lod_level for compile-time.
auto in_lod_level = ctx->GetLoDLevel("X");
PADDLE_ENFORCE_GT(in_lod_level, 0, platform::errors::InvalidArgument(
Expand All @@ -36,13 +72,18 @@ class SequencePoolOp : public framework::OperatorWithKernel {
"lod level %u.",
in_lod_level));
ctx->SetLoDLevel("Out", in_lod_level - 1);

int rank = ins_dim.size();
std::vector<int64_t> out_dim;
out_dim = {-1, ins_dim[rank - 1]};
out_dims = phi::make_ddim(out_dim);
}

ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->SetOutputDim("Out", out_dims);
if (ctx->Attrs().Get<std::string>("pooltype") == "MAX") {
OP_INOUT_CHECK(ctx->HasOutput("MaxIndex"), "Output", "MaxIndex",
"SequencePool");
ctx->SetOutputDim("MaxIndex", ctx->GetInputDim("X"));
ctx->SetOutputDim("MaxIndex", out_dims);
}
}
};
Expand Down

0 comments on commit 5845fd0

Please sign in to comment.