From 5845fd0bb44851e9d784d75eeb81e316350918be Mon Sep 17 00:00:00 2001 From: Fan Zhang Date: Fri, 15 Jul 2022 13:24:19 +0800 Subject: [PATCH] add runtime inferShape in sequence_pool op (#41) --- .../sequence_ops/sequence_pool_op.cc | 47 +++++++++++++++++-- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc index 01990ebb73291..069e6736fe50d 100644 --- a/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc +++ b/paddle/fluid/operators/sequence_ops/sequence_pool_op.cc @@ -27,7 +27,43 @@ class SequencePoolOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SequencePool"); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SequencePool"); - if (!ctx->IsRuntime()) { + auto ins_dim = ctx->GetInputDim("X"); + framework::DDim out_dims; + PADDLE_ENFORCE_EQ(ins_dim.size(), 2, + platform::errors::InvalidArgument( + "The dims size of first input should be equal to 2, " + "but received value is %d.", + ins_dim.size())); + + if (ctx->IsRuntime()) { + auto inputs_tensor = ctx->GetInputVarPtrs("X"); + + int batch_size = -1; + int rank = ins_dim.size(); + int cur_batch_size = 0; + framework::Variable* x_var = + BOOST_GET(framework::Variable*, inputs_tensor[0]); + const auto& x_tensor = x_var->Get(); + const auto& x_lod = x_tensor.lod(); + if (x_lod.size() > 0) { + cur_batch_size = x_lod[0].size() - 1; + } else { + cur_batch_size = x_tensor.dims()[0]; + } + if (batch_size == -1) { + batch_size = cur_batch_size; + } else { + PADDLE_ENFORCE_EQ(batch_size, cur_batch_size, + platform::errors::PreconditionNotMet( + "The batch size of all input should be same, " + "please check, last batch_size is %d, current " + "batch_size is %d", + batch_size, cur_batch_size)); + } + std::vector out_dim; + out_dim = {batch_size, ins_dim[rank - 1]}; + out_dims = phi::make_ddim(out_dim); + } else { // Check the lod_level for compile-time. auto in_lod_level = ctx->GetLoDLevel("X"); PADDLE_ENFORCE_GT(in_lod_level, 0, platform::errors::InvalidArgument( @@ -36,13 +72,18 @@ class SequencePoolOp : public framework::OperatorWithKernel { "lod level %u.", in_lod_level)); ctx->SetLoDLevel("Out", in_lod_level - 1); + + int rank = ins_dim.size(); + std::vector out_dim; + out_dim = {-1, ins_dim[rank - 1]}; + out_dims = phi::make_ddim(out_dim); } - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->SetOutputDim("Out", out_dims); if (ctx->Attrs().Get("pooltype") == "MAX") { OP_INOUT_CHECK(ctx->HasOutput("MaxIndex"), "Output", "MaxIndex", "SequencePool"); - ctx->SetOutputDim("MaxIndex", ctx->GetInputDim("X")); + ctx->SetOutputDim("MaxIndex", out_dims); } } };