Conversation

@pangyoki (Contributor) commented Jan 15, 2021

PR types

New features

PR changes

APIs

Describe

Backup PR of PR #30103.

Dynamic graph (dygraph) support for the Inplace strategy: instead of creating a new output VarBase, the input VarBase is passed directly to the output, aligning with the ops that already support the Inplace strategy in static graph mode. New Python API interfaces are added for the ops covered by this strategy.

  • Supported ops
    Aligned with the 38 ops that already support the Inplace strategy in static graph mode. This PR adds the Inplace strategy for 8 ops: squeeze2, unsqueeze2, reshape2, scatter, elu, relu, softmax, tanh. Inplace support for the remaining ops will be added in follow-up PRs.

  • API interface
    For each op that uses the Inplace strategy, a new Python API is added whose name is the existing API name with a trailing underscore _. Users can choose whether to build their networks with these APIs.
    For the 8 ops above, the corresponding Python APIs are: squeeze_, unsqueeze_, reshape_, scatter_, elu_, relu_, softmax_, tanh_.
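The naming convention can be illustrated with a toy tensor class (a hedged sketch; `ToyTensor` is hypothetical and not a Paddle class): the plain method returns a new object, while the trailing-underscore variant mutates the receiver and returns it.

```python
# Toy illustration of the out-of-place vs. inplace (trailing underscore)
# naming convention; ToyTensor is hypothetical, not a Paddle class.
class ToyTensor:
    def __init__(self, shape):
        self.shape = list(shape)

    def squeeze(self, axis):
        # Out-of-place: build and return a new tensor.
        return ToyTensor([d for i, d in enumerate(self.shape) if i != axis])

    def squeeze_(self, axis):
        # Inplace: mutate this tensor and return it, so the result
        # is the same object as the input.
        self.shape = [d for i, d in enumerate(self.shape) if i != axis]
        return self

x = ToyTensor([5, 1, 10])
y = x.squeeze(1)    # new object, x unchanged
z = x.squeeze_(1)   # x itself is modified
print(y.shape, x.shape, z is x)  # [5, 10] [5, 10] True
```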

  • Generated op function
    The dygraph Inplace mechanism is added at the op_function_generator level. Taking the squeeze2 op as an example, the op function generated after adding the Inplace strategy is imperative_squeeze2_.
    The main changes are:

    • Check whether the var to be modified inplace is a leaf node with stop_gradient=False; if it is, an error is raised.
    • Increment the inplace var's inplace_version by 1.
    • Pass the input Var directly to the output.
    std::tuple<std::shared_ptr<imperative::VarBase>,
               std::shared_ptr<imperative::VarBase>>
    imperative_squeeze2_(const py::handle& X_, const py::args& args) {
      auto X = CastPyHandleToVarBase("squeeze2", "X", 0, X_, false);
      framework::AttributeMap attrs;
      ConstructAttrMapFromPyArgs("squeeze2", 1, &attrs, args);
      {
        py::gil_scoped_release release;
        auto tracer = imperative::GetCurrentTracer();

        PADDLE_ENFORCE_EQ(
            X->IsLeaf() && !X->OverridedStopGradient(), false,
            platform::errors::InvalidArgument(
                "Leaf Var (%s) that doesn't stop gradient can't use "
                "inplace strategy.",
                X->Name()));
        X->BumpInplaceVersion();
        VLOG(3) << "Var(" << X->Name() << ") uses Inplace Strategy.";

        imperative::NameVarBaseMap outs = {
            {"Out", {X}},
            {"XShape",
             {std::shared_ptr<imperative::VarBase>(
                 new imperative::VarBase(tracer->GenerateUniqueName()))}}};
        imperative::NameVarBaseMap ins = {{"X", {X}}};

        tracer->TraceOp("squeeze2", ins, outs, attrs, {{"X", "Out"}});
        return std::make_tuple(outs["Out"][0], outs["XShape"][0]);
      }
    }
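The control flow of the generated function can be sketched in plain Python (a hedged illustration only; `ToyVar` and `imperative_squeeze2_` below are hypothetical stand-ins for `imperative::VarBase` and the generated C++ code):

```python
# Minimal sketch of the three inplace steps in the generated op function:
# leaf check, inplace-version bump, and passing the input through as the
# output. ToyVar is a hypothetical stand-in for imperative::VarBase.
class ToyVar:
    def __init__(self, name, is_leaf=True, stop_gradient=True):
        self.name = name
        self.is_leaf = is_leaf
        self.stop_gradient = stop_gradient
        self.inplace_version = 0

def imperative_squeeze2_(x):
    # 1) A leaf var that requires grad must not be modified inplace.
    if x.is_leaf and not x.stop_gradient:
        raise ValueError(
            f"Leaf Var ({x.name}) that doesn't stop gradient "
            "can't use inplace strategy.")
    # 2) Record the modification by bumping the inplace version.
    x.inplace_version += 1
    # 3) The input var itself is reused as the "Out" output.
    outs = {"Out": [x]}
    return outs["Out"][0]

v = ToyVar("x", is_leaf=False, stop_gradient=False)
out = imperative_squeeze2_(v)
print(out is v, v.inplace_version)  # True 1
```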
    
  • Example of using the squeeze_() API

    import paddle
    
    x = paddle.rand([5, 1, 10])
    output = paddle.squeeze_(x, axis=1)
    
    print(x.shape)  # [5, 10]
    print(output.shape)  # [5, 10]
    print(id(x) == id(output))  # True
    
  • Inplace error cases

    • Leaf node with stop_gradient=False
    import paddle
    
    a = paddle.rand([2,3,1])
    a.stop_gradient = False
    a.squeeze_()
    # ValueError: (InvalidArgument) Leaf Var (dygraph_tmp_2) that doesn't stop gradient can't use inplace strategy.
    
    • Inplace modification of a tensor that is needed for gradient computation
    import paddle
    
    a = paddle.rand([2,3,1])
    a.stop_gradient = False
    
    b = a ** 2
    c = b ** 2
    b.squeeze_()
    
    c.sum().backward()
    # RuntimeError: (PermissionDenied) Tensor 'dygraph_tmp_2' used in gradient computation in grad op 'elementwise_pow_grad' has been modified by an inplace operation. Its version is 1 but the expected version is 0. Please fix your code to void calling an inplace operator after using the Tensor which will used in gradient computation
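The mechanism behind this error can be sketched as follows (a hedged illustration; `ToyVar`, `record_for_backward`, and `check_before_backward` are hypothetical names, not Paddle internals): the backward op snapshots a tensor's inplace_version when the forward op runs and verifies it before computing gradients.

```python
# Sketch of inplace-version checking: the grad op records the tensor's
# version at forward time and verifies it before running backward.
# All names here are hypothetical, not Paddle internals.
class ToyVar:
    def __init__(self, name):
        self.name = name
        self.inplace_version = 0

def record_for_backward(var):
    # Snapshot the version when the forward op saves the tensor.
    return (var, var.inplace_version)

def check_before_backward(snapshot, grad_op):
    var, expected = snapshot
    if var.inplace_version != expected:
        raise RuntimeError(
            f"Tensor '{var.name}' used in gradient computation in grad op "
            f"'{grad_op}' has been modified by an inplace operation. "
            f"Its version is {var.inplace_version} but the expected "
            f"version is {expected}.")

b = ToyVar("dygraph_tmp_2")
snap = record_for_backward(b)   # b is saved for elementwise_pow_grad
b.inplace_version += 1          # b.squeeze_() bumps the version
try:
    check_before_backward(snap, "elementwise_pow_grad")
except RuntimeError as e:
    print(e)                    # version mismatch is reported
```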
    
  • Implementation notes

    • Add the dygraph Inplace strategy logic to the op function generator.
    • Backward ops do not use the Inplace strategy; temporary vars have to be created for the backward ops of inplace ops.
    • Building the backward network for inplace ops runs into a recursion problem: the SetInput and SetOutput logic in GradOpMaker is modified.
    • Inplace vars cause a gradient accumulation problem: grad_node is added to the key of accumulator_ in basic_engine.
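The last point can be illustrated with a toy accumulator table (a hedged sketch under the assumption that, after inplace reuse, the same var can receive gradients under different grad nodes; `accumulate` and the string keys are hypothetical, not Paddle's `basic_engine`):

```python
# Sketch of the gradient-accumulation fix: an inplace op reuses the same
# VarBase as both input and output, so keying accumulators by the var
# alone would merge gradients belonging to different grad nodes. Adding
# the grad node to the key keeps them separate. Names are hypothetical.
accumulators = {}

def accumulate(grad_node, var, grad):
    key = (grad_node, var)  # grad_node is part of the key
    accumulators[key] = accumulators.get(key, 0.0) + grad

# The same var "x" receives gradients under two different grad nodes;
# with the composite key the two sums stay separate.
accumulate("squeeze2_grad", "x", 1.0)
accumulate("pow_grad", "x", 2.0)
print(accumulators[("squeeze2_grad", "x")])  # 1.0
print(accumulators[("pow_grad", "x")])       # 2.0
```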
