Merge branch 'feature/dynamic-recurrent-op' into feature/dynamic-recurrent-op-forward-test
Superjomn committed Oct 9, 2017
2 parents 50c364e + d30ada2 commit 4a0cc85
Showing 6 changed files with 44 additions and 12 deletions.
4 changes: 4 additions & 0 deletions cmake/configure.cmake
@@ -24,6 +24,10 @@ if(WITH_DOUBLE)
     add_definitions(-DPADDLE_TYPE_DOUBLE)
 endif(WITH_DOUBLE)
 
+if(WITH_TESTING)
+    add_definitions(-DPADDLE_WITH_TESTING)
+endif(WITH_TESTING)
+
 if(NOT WITH_TIMER)
     add_definitions(-DPADDLE_DISABLE_TIMER)
 endif(NOT WITH_TIMER)
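The WITH_TESTING guard above surfaces to C++ only as the PADDLE_WITH_TESTING preprocessor definition. A minimal, standalone sketch (not part of this commit) of how code can check that the definition actually reached the compiler; the file name and the compile command are illustrative only:

// check_testing_flag.cc, e.g. compile with: g++ -DPADDLE_WITH_TESTING check_testing_flag.cc
#include <iostream>

int main() {
#ifdef PADDLE_WITH_TESTING
  std::cout << "built with testing support\n";    // path taken when WITH_TESTING=ON
#else
  std::cout << "built without testing support\n";
#endif
  return 0;
}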
10 changes: 5 additions & 5 deletions paddle/operators/dynamic_recurrent_op.cc
@@ -30,11 +30,11 @@ inline void CreateVariables(Scope& scope,
 
 }  // namespace detail
 
-class DynamicRecurrentAlgorithmProtoAndCheckerMaker
+class DynamicRecurrentOpProtoAndCheckerMaker
     : public framework::OpProtoAndCheckerMaker {
  public:
-  DynamicRecurrentAlgorithmProtoAndCheckerMaker(
-      framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+  DynamicRecurrentOpProtoAndCheckerMaker(framework::OpProto* proto,
+                                         framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     const auto& name = DynamicRecurrentOp::kArgName;
     // inputs and outputs stored in proto
@@ -268,5 +268,5 @@ void DynamicRecurrentGradientOp::Run(
 }  // namespace paddle
 
 REGISTER_OP_WITHOUT_GRADIENT(
-    dynamic_recurrent_op, paddle::operators::DynamicRecurrentOp,
-    paddle::operators::DynamicRecurrentAlgorithmProtoAndCheckerMaker);
+    dynamic_recurrent, paddle::operators::DynamicRecurrentOp,
+    paddle::operators::DynamicRecurrentOpProtoAndCheckerMaker);
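The registration change renames only the operator's type string, from dynamic_recurrent_op to dynamic_recurrent; that string is the key under which the operator is looked up, and the Python factory added later in this commit uses the same "dynamic_recurrent" value. A toy sketch of the idea, not Paddle's actual registry, with a hand-rolled map standing in for REGISTER_OP_WITHOUT_GRADIENT:

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct OpBase {
  virtual ~OpBase() = default;
  virtual void Run() = 0;
};

struct ToyDynamicRecurrent : OpBase {
  void Run() override { std::cout << "running dynamic_recurrent\n"; }
};

// Stand-in for the global registry that the REGISTER_* macros populate.
std::map<std::string, std::function<std::unique_ptr<OpBase>()>> registry = {
    {"dynamic_recurrent", [] { return std::unique_ptr<OpBase>(new ToyDynamicRecurrent); }},
};

int main() {
  registry.at("dynamic_recurrent")->Run();  // clients address the op by its type string
  return 0;
}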
4 changes: 4 additions & 0 deletions paddle/operators/dynamic_recurrent_op.h
@@ -14,7 +14,9 @@
 
 #pragma once
 
+#ifdef PADDLE_WITH_TESTING
 #include "gtest/gtest.h"
+#endif
 
 #include "paddle/framework/lod_tensor.h"
 #include "paddle/framework/operator.h"
@@ -142,6 +144,7 @@ class DynamicRecurrentOp : public framework::OperatorBase {
   mutable rnn::Argument arg_;
   mutable ArgCache cache_;
 
+#ifdef PADDLE_WITH_TESTING
   friend class DynamicRecurrentOpTestHelper;
   FRIEND_TEST(DynamicRecurrentOpTestHelper, SplitInputs);
   FRIEND_TEST(DynamicRecurrentOpTestHelper, CreateCache);
@@ -150,6 +153,7 @@ class DynamicRecurrentOp : public framework::OperatorBase {
   FRIEND_TEST(DynamicRecurrentOpTestHelper, WriteStepOutputs);
   FRIEND_TEST(DynamicRecurrentOpTestHelper, InitStates);
   FRIEND_TEST(DynamicRecurrentOpTestHelper, ConcatOutputs);
+#endif
 };
 
 class DynamicRecurrentGradientOp : public framework::OperatorBase {
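The header edit mirrors the CMake change: gtest is included, and the test-helper friendships declared, only when PADDLE_WITH_TESTING is defined, so non-test builds of the operator never need the gtest headers. A standalone sketch of the same guard pattern; Counter and CounterTest are illustrative stand-ins, not Paddle code:

#ifdef PADDLE_WITH_TESTING
#include "gtest/gtest.h"  // provides FRIEND_TEST
#endif

class Counter {
 public:
  void Add() { ++value_; }

 private:
  int value_ = 0;

#ifdef PADDLE_WITH_TESTING
  // Expands to: friend class CounterTest_ReadsPrivateValue_Test;
  // giving exactly that gtest case access to value_.
  FRIEND_TEST(CounterTest, ReadsPrivateValue);
#endif
};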
5 changes: 4 additions & 1 deletion paddle/operators/dynamic_recurrent_op_test.cc
@@ -86,7 +86,10 @@ class DynamicRecurrentOpTestHelper : public ::testing::Test {
     CreateVar(scope, "out0", framework::make_ddim({10, 20}), place);
     auto* in0 = CreateVar(scope, "in0", framework::make_ddim({10, 8}), place);
     // 10 instances with 4 sentences, whose lengths are 4, 3, 2, 1 respectively.
-    framework::LoD in0_lod({{0, 4, 7, 9, 10}});
+    framework::LoD in0_lod(1);
+    for (int x : std::vector<int>{0, 4, 7, 9, 10}) {
+      in0_lod[0].push_back(x);
+    }
     in0->set_lod(in0_lod);
     in0->Resize(framework::make_ddim({10, 8}));
     // set the content, each sentence content is seqid.batchid
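Both the removed brace-initializer and the new push_back loop produce the same single-level LoD with offsets {0, 4, 7, 9, 10}; consecutive offsets bound one sequence, which is where the comment's lengths 4, 3, 2, 1 over the 10 rows come from. A small self-contained sketch of that offsets-to-lengths reading, using a plain std::vector rather than Paddle's LoD type:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // One LoD level: offsets into the 10-row batch, matching in0_lod above.
  std::vector<std::size_t> offsets = {0, 4, 7, 9, 10};
  for (std::size_t i = 0; i + 1 < offsets.size(); ++i) {
    // Sequence i spans rows [offsets[i], offsets[i+1]); prints lengths 4, 3, 2, 1.
    std::cout << "sequence " << i << ": length " << offsets[i + 1] - offsets[i] << "\n";
  }
  return 0;
}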
11 changes: 5 additions & 6 deletions paddle/pybind/pybind.cc
@@ -376,14 +376,13 @@ All parameter, weight, gradient are variables in Paddle.
                -> void { self.SetStepNet(net.Clone()); })
       .def("get_state",
            [](operators::DynamicRecurrentOp &self, const std::string &name)
-               -> TensorArray & { return self.state(name); })
+               -> const TensorArray & { return self.state(name); })
       .def("get_step_input",
            [](operators::DynamicRecurrentOp &self, const std::string &name)
-               -> TensorArray & { return self.step_input(name); })
-      .def("get_step_output", [](operators::DynamicRecurrentOp &self,
-                                 const std::string &name) -> TensorArray & {
-        return self.step_output(name);
-      });
+               -> const TensorArray & { return self.step_input(name); })
+      .def("get_step_output",
+           [](operators::DynamicRecurrentOp &self, const std::string &name)
+               -> const TensorArray & { return self.step_output(name); });
 
   // cond_op
   py::class_<operators::CondOp, OperatorBase>(m, "CondOp")
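The pybind edit changes the three accessors to return const TensorArray & instead of TensorArray &, matching const-qualified getters on the operator. A minimal pybind11 sketch of the same binding shape; Array, Holder, and the module name are stand-ins rather than Paddle types, and the return_value_policy is added here only to make the standalone example's lifetime handling explicit:

#include <pybind11/pybind11.h>
#include <string>

namespace py = pybind11;

struct Array {
  int size() const { return 3; }
};

struct Holder {
  Array state_;
  const Array &state(const std::string &name) const { return state_; }
};

PYBIND11_MODULE(example, m) {
  py::class_<Array>(m, "Array").def("size", &Array::size);
  py::class_<Holder>(m, "Holder")
      .def(py::init<>())
      .def("get_state",
           [](Holder &self, const std::string &name) -> const Array & {
             return self.state(name);
           },
           py::return_value_policy::reference_internal);
}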
22 changes: 22 additions & 0 deletions python/paddle/v2/framework/op.py
@@ -219,6 +219,27 @@ def __call__(self, *args, **kwargs):
         return core.RecurrentOp.create(proto.SerializeToString())
 
 
+class __DynamicRecurrentOp__(object):
+    __proto__ = None
+    type = "dynamic_recurrent"
+
+    def __init__(self):
+        # cache recurrent_op's proto
+        if self.__proto__ is None:
+            for op_proto in get_all_op_protos():
+                if op_proto.type == self.type:
+                    self.__proto__ = op_proto
+
+    def __call__(self, *args, **kwargs):
+        if self.type not in args and "type" not in kwargs:
+            kwargs["type"] = self.type
+        # create proto
+        create_method = OpDescCreationMethod(self.__proto__)
+        proto = create_method(*args, **kwargs)
+        # create rnnop
+        return core.DynamicRecurrentOp.create(proto.SerializeToString())
+
+
 class __CondOp__(object):
     __proto__ = None
     type = "cond"
@@ -242,4 +263,5 @@ def __call__(self, *args, **kwargs):
 
 Operator = OperatorFactory()  # The default global factory
 RecurrentOp = __RecurrentOp__()
+DynamicRecurrentOp = __DynamicRecurrentOp__()
 CondOp = __CondOp__()
