Skip to content

Commit

Permalink
fix error msg (PaddlePaddle#27887)
Browse files Browse the repository at this point in the history
  • Loading branch information
hutuxian authored Oct 13, 2020
1 parent 426de25 commit 3f2a6ab
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
19 changes: 12 additions & 7 deletions paddle/fluid/operators/pull_box_sparse_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,14 @@ class PullBoxSparseOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("Ids").size(), 1UL,
"Inputs(Ids) of PullBoxSparseOp should not be empty.");
PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL,
"Outputs(Out) of PullBoxSparseOp should not be empty.");
PADDLE_ENFORCE_GE(
ctx->Inputs("Ids").size(), 1UL,
platform::errors::InvalidArgument(
"Inputs(Ids) of PullBoxSparseOp should not be empty."));
PADDLE_ENFORCE_GE(
ctx->Outputs("Out").size(), 1UL,
platform::errors::InvalidArgument(
"Outputs(Out) of PullBoxSparseOp should not be empty."));
auto hidden_size = static_cast<int64_t>(ctx->Attrs().Get<int>("size"));
auto all_ids_dim = ctx->GetInputsDim("Ids");
const size_t n_ids = all_ids_dim.size();
Expand All @@ -34,9 +38,10 @@ class PullBoxSparseOp : public framework::OperatorWithKernel {
const auto ids_dims = all_ids_dim[i];
int ids_rank = ids_dims.size();
PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
"Shape error in %lu id, the last dimension of the "
"'Ids' tensor must be 1.",
i);
platform::errors::InvalidArgument(
"Shape error in %lu id, the last dimension of the "
"'Ids' tensor must be 1.",
i));
auto out_dim = framework::vectorize(
framework::slice_ddim(ids_dims, 0, ids_rank - 1));
out_dim.push_back(hidden_size);
Expand Down
15 changes: 15 additions & 0 deletions python/paddle/fluid/tests/unittests/test_boxps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import numpy as np
Expand Down Expand Up @@ -87,5 +88,19 @@ def test_run_cmd(self):
self.assertTrue(ret2 == 0)


class TestPullBoxSparseOP(unittest.TestCase):
""" TestCases for _pull_box_sparse op"""

def test_pull_box_sparse_op(self):
paddle.enable_static()
program = fluid.Program()
with fluid.program_guard(program):
x = fluid.layers.data(
name='x', shape=[1], dtype='int64', lod_level=0)
y = fluid.layers.data(
name='y', shape=[1], dtype='int64', lod_level=0)
emb_x, emb_y = _pull_box_sparse([x, y], size=1)


if __name__ == '__main__':
unittest.main()

0 comments on commit 3f2a6ab

Please sign in to comment.