10 changes: 5 additions & 5 deletions paddle/operators/dropout_op.cc
@@ -30,7 +30,7 @@ class DropoutOp : public framework::OperatorWithKernel {
 
     auto x_dims = ctx->GetInputDim("X");
     ctx->SetOutputDim("Out", x_dims);
-    if (ctx->Attrs().Get<bool>("is_training") == 1) {
+    if (ctx->Attrs().Get<bool>("is_training") == true) {
       ctx->SetOutputDim("Mask", x_dims);
     }
     ctx->ShareLoD("X", /*->*/ "Out");
@@ -43,7 +43,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
   DropoutOpMaker(framework::OpProto* proto,
                  framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddAttr<AttrType>("dropout_prob", "Probability of setting units to zero.")
+    AddAttr<float>("dropout_prob", "Probability of setting units to zero.")
         .SetDefault(.5f);
     AddAttr<bool>("is_training", "Whether in training phase.").SetDefault(true);
     AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
@@ -69,16 +69,16 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), 1,
+    PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), true,
                       "GradOp is only callable when is_training is true");
 
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Mask"), "Mask must not be null.");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "Input(Out@GRAD) must not be null.");
 
-    PADDLE_ENFORCE_GE(ctx->Attrs().Get<AttrType>("dropout_prob"), 0);
-    PADDLE_ENFORCE_LE(ctx->Attrs().Get<AttrType>("dropout_prob"), 1);
+    PADDLE_ENFORCE_GE(ctx->Attrs().Get<float>("dropout_prob"), 0);
+    PADDLE_ENFORCE_LE(ctx->Attrs().Get<float>("dropout_prob"), 1);
     auto x_dims = ctx->GetInputDim("X");
     auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
     PADDLE_ENFORCE_EQ(x_dims, out_dims,
4 changes: 2 additions & 2 deletions paddle/operators/dropout_op.h
@@ -33,15 +33,15 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
     auto* y = context.Output<Tensor>("Out");
     const auto* x_data = x->data<T>();
     auto* y_data = y->mutable_data<T>(context.GetPlace());
-    AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
+    float dropout_prob = context.Attr<float>("dropout_prob");
 
     if (context.Attr<bool>("is_training")) {
       auto* mask = context.Output<Tensor>("Mask");
       auto* mask_data = mask->mutable_data<T>(context.GetPlace());
       int seed = context.Attr<int>("seed");
       std::minstd_rand engine;
       engine.seed(seed);
-      std::uniform_real_distribution<AttrType> dist(0, 1);
+      std::uniform_real_distribution<float> dist(0, 1);
       size_t size = framework::product(mask->dims());
       for (size_t i = 0; i < size; ++i) {
         if (dist(engine) < dropout_prob) {
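For reference, the masking loop that CPUDropoutKernel runs can be written as a small standalone sketch. This is an illustration of the algorithm only, not Paddle code: it uses plain Python lists and the standard random module instead of tensors and std::minstd_rand, and the function name is made up.

```python
import random

def dropout_forward(x, dropout_prob, seed):
    """Sketch of the forward dropout masking loop shown in dropout_op.h:
    each element is zeroed with probability dropout_prob, and the mask
    records which elements were kept (1.0) or dropped (0.0)."""
    rng = random.Random(seed)
    y, mask = [], []
    for v in x:
        if rng.random() < dropout_prob:
            mask.append(0.0)
            y.append(0.0)
        else:
            mask.append(1.0)
            y.append(v)
    return y, mask

# Example: roughly half of the elements are zeroed with dropout_prob=0.5.
y, mask = dropout_forward([1.0, 2.0, 3.0, 4.0], dropout_prob=0.5, seed=0)
```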
28 changes: 23 additions & 5 deletions python/paddle/v2/framework/layers.py
@@ -97,15 +97,28 @@ def _convert_(name):
 
 def _create_op_func_(op_type):
     op_proto = OpProtoHolder.instance().get_op_proto(op_type)
-    if len(op_proto.outputs) != 1:
+    not_intermediate_outputs = \
+        filter(lambda output: not output.intermediate, op_proto.outputs)
+    intermediate_outputs = \
+        filter(lambda output: output.intermediate, op_proto.outputs)
+
+    if len(not_intermediate_outputs) != 1:
         raise ValueError(
-            "Only one output operator can be automatically generated")
+            "Only one not intermediate output operator can be automatically generated"
+        )
 
-    if op_proto.outputs[0].duplicable:
+    if not_intermediate_outputs[0].duplicable:
         raise ValueError(
             "Only not duplicable op can be automatically generated")
 
-    o_name = op_proto.outputs[0].name
+    for output in intermediate_outputs:
+        if output.duplicable:
+            raise ValueError(
+                "Only when all intermediate ops are not duplicable, "
+                "this op can be automatically generated")
+
+    o_name = not_intermediate_outputs[0].name
+    intermediate_output_names = [output.name for output in intermediate_outputs]
 
     def func(**kwargs):
         helper = LayerHelper(op_type, **kwargs)
@@ -128,9 +141,13 @@ def func(**kwargs):
                         "operator {0} must input same dtype".format(op_type))
             inputs[ipt.name] = val
 
+        outputs = dict()
         out = helper.create_tmp_variable(dtype=dtype)
+        outputs[o_name] = [out]
+        for name in intermediate_output_names:
+            outputs[name] = [helper.create_tmp_variable(dtype=dtype)]
         helper.append_op(
-            type=op_type, inputs=inputs, outputs={o_name: [out]}, attrs=kwargs)
+            type=op_type, inputs=inputs, outputs=outputs, attrs=kwargs)
         return out
 
     func.__name__ = op_type
@@ -141,6 +158,7 @@ def func(**kwargs):
 
 _create_op_func_('mean')
 _create_op_func_('mul')
+_create_op_func_('dropout')
 
 
 def concat(input, axis, program=None, init_program=None):
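With dropout registered above, _create_op_func_ should generate a layers.dropout wrapper alongside mean and mul. A rough usage sketch follows; it is hedged, not taken from this PR: the input keyword is assumed to be the snake_case form of the operator's "X" input, the attribute names mirror those declared in dropout_op.cc, and layers.data with its data_type argument is assumed to be the existing data layer in this module.

```python
# Hypothetical usage of the auto-generated dropout layer; the surrounding
# program/init_program plumbing follows whatever the other layers in this
# module use.
import paddle.v2.framework.layers as layers

# Assumed signature of the existing data layer in this module.
x = layers.data(name='x', shape=[784], data_type='float32')

# dropout_prob, is_training and seed are the attributes declared by
# DropoutOpMaker; the kept Out tensor is returned, Mask is an intermediate
# output created internally.
out = layers.dropout(x=x, dropout_prob=0.5, is_training=True, seed=1)
```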