Update the demo code and the doc of varbase.backward. #26506

Merged: 6 commits, Aug 27, 2020
33 changes: 0 additions & 33 deletions paddle/fluid/imperative/backward_strategy.h

This file was deleted.

9 changes: 5 additions & 4 deletions paddle/fluid/imperative/basic_engine.cc
@@ -30,12 +30,13 @@
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/profiler.h"

DECLARE_bool(sort_sum_gradient);

namespace paddle {
namespace imperative {

void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy,
bool retain_graph) {
backward_strategy_ = strategy;
void BasicEngine::Init(VarBase* var, bool retain_graph) {
sorted_sum_gradient_ = FLAGS_sort_sum_gradient;
retain_graph_ = retain_graph;
init_node_ = var->GradVarBase()->GradNode();
var->GradVarBase()->ClearGradNode();
@@ -105,7 +106,7 @@ void BasicEngine::PrepareGradAccumulators(const OpBase& op) {

auto& accumulator = accumulators_[var.get()];
if (!accumulator) {
if (backward_strategy_.sorted_sum_gradient_) {
if (sorted_sum_gradient_) {
accumulator.reset(new SortedGradientAccumulator(var.get()));
} else {
accumulator.reset(new EagerGradientAccumulator(var.get()));
6 changes: 2 additions & 4 deletions paddle/fluid/imperative/basic_engine.h
@@ -18,7 +18,6 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/imperative/backward_strategy.h"
#include "paddle/fluid/imperative/engine.h"
#include "paddle/fluid/imperative/gradient_accumulator.h"

@@ -30,8 +29,7 @@ class OpBase;

class BasicEngine : public Engine {
public:
void Init(VarBase* var, const detail::BackwardStrategy& strategy,
bool retain_graph = false);
void Init(VarBase* var, bool retain_graph = false);

void Execute() override;

@@ -46,7 +44,7 @@ class BasicEngine : public Engine {

private:
std::shared_ptr<GradOpNode> init_node_;
detail::BackwardStrategy backward_strategy_;
bool sorted_sum_gradient_;
std::unordered_map<GradOpNode*, size_t> node_deps_;
std::unordered_map<VariableWrapper*, std::unique_ptr<GradientAccumulator>>
accumulators_;
22 changes: 11 additions & 11 deletions paddle/fluid/imperative/partial_grad_engine.cc
@@ -33,6 +33,8 @@
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/string/string_helper.h"

DECLARE_bool(sort_sum_gradient);

namespace paddle {
namespace imperative {

@@ -529,8 +531,7 @@ class PartialGradTask {
const std::vector<std::shared_ptr<VarBase>> &output_targets,
const std::vector<std::shared_ptr<VarBase>> &output_grads,
const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
const platform::Place &place,
const detail::BackwardStrategy &strategy, bool create_graph,
const platform::Place &place, bool create_graph,
bool retain_graph, bool allow_unused, bool only_inputs);

std::vector<std::shared_ptr<VarBase>> Run();
@@ -577,23 +578,22 @@ class PartialGradTask {
bool retain_graph_;
bool allow_unused_;
bool only_inputs_;
detail::BackwardStrategy strategy_;
bool sorted_sum_gradient_{FLAGS_sort_sum_gradient};
};

PartialGradTask::PartialGradTask(
const std::vector<std::shared_ptr<VarBase>> &input_targets,
const std::vector<std::shared_ptr<VarBase>> &output_targets,
const std::vector<std::shared_ptr<VarBase>> &output_grads,
const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
const platform::Place &place, const detail::BackwardStrategy &strategy,
bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs) {
const platform::Place &place, bool create_graph, bool retain_graph,
bool allow_unused, bool only_inputs) {
input_targets_ = input_targets;
place_ = place;
create_graph_ = create_graph;
retain_graph_ = retain_graph;
allow_unused_ = allow_unused;
only_inputs_ = only_inputs;
strategy_ = strategy;

PADDLE_ENFORCE_EQ(only_inputs_, true,
platform::errors::Unimplemented(
@@ -981,7 +981,7 @@ void PartialGradTask::PrepareInitialGradientAccumulators(const OpBase *op) {

if (!accumulator) {
accumulator.reset(new GradientAccumulationInfo(
var, strategy_.sorted_sum_gradient_, create_graph_));
var, sorted_sum_gradient_, create_graph_));
}

accumulator->IncreaseTotalRefCnt();
@@ -1033,11 +1033,11 @@ PartialGradEngine::PartialGradEngine(
const std::vector<std::shared_ptr<VarBase>> &output_targets,
const std::vector<std::shared_ptr<VarBase>> &output_grads,
const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
const platform::Place &place, const detail::BackwardStrategy &strategy,
bool create_graph, bool retain_graph, bool allow_unused, bool only_inputs)
const platform::Place &place, bool create_graph, bool retain_graph,
bool allow_unused, bool only_inputs)
: task_(new PartialGradTask(input_targets, output_targets, output_grads,
no_grad_vars, place, strategy, create_graph,
retain_graph, allow_unused, only_inputs)) {}
no_grad_vars, place, create_graph, retain_graph,
allow_unused, only_inputs)) {}

PartialGradEngine::~PartialGradEngine() { Clear(); }

4 changes: 1 addition & 3 deletions paddle/fluid/imperative/partial_grad_engine.h
@@ -16,7 +16,6 @@

#include <memory>
#include <vector>
#include "paddle/fluid/imperative/backward_strategy.h"
#include "paddle/fluid/imperative/engine.h"
#include "paddle/fluid/platform/place.h"

@@ -33,8 +32,7 @@ class PartialGradEngine : public Engine {
const std::vector<std::shared_ptr<VarBase>> &output_targets,
const std::vector<std::shared_ptr<VarBase>> &output_grads,
const std::vector<std::shared_ptr<VarBase>> &no_grad_vars,
const platform::Place &place,
const detail::BackwardStrategy &strategy, bool create_graph,
const platform::Place &place, bool create_graph,
bool retain_graph, bool allow_unused, bool only_inputs);

~PartialGradEngine();
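PartialGradTask now reads FLAGS_sort_sum_gradient at construction instead of taking a BackwardStrategy, so fluid.dygraph.grad (which runs through PartialGradEngine) picks the flag up as well. A minimal sketch of exercising that path, assuming the flag is exposed to Python through fluid.set_flags via the registration in global_value_getter_setter.cc below; treat that call as an assumption, not part of this diff.

.. code-block:: python

    import numpy as np
    import paddle.fluid as fluid

    # Assumed exposure of the new flag; not part of this diff.
    fluid.set_flags({'FLAGS_sort_sum_gradient': True})

    with fluid.dygraph.guard():
        x = fluid.dygraph.to_variable(np.ones([2, 2], np.float32))
        x.stop_gradient = False
        y = fluid.layers.scale(x, scale=3.0)
        # dygraph.grad drives PartialGradEngine; the flag selects the sorted
        # gradient accumulator inside the task.
        dx = fluid.dygraph.grad(outputs=[y], inputs=[x])[0]
        print(dx.numpy())  # expected: a 2x2 array of 3.0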
6 changes: 2 additions & 4 deletions paddle/fluid/imperative/tests/test_tracer.cc
@@ -240,9 +240,8 @@ TEST(test_tracer, test_trace_op_with_multi_device_inputs) {
framework::AttributeMap reduce_attr_map;
tracer.TraceOp("reduce_sum", reduce_in, reduce_out, reduce_attr_map,
gpu_place, true);
detail::BackwardStrategy back_st;
imperative::BasicEngine engine;
engine.Init(reduce_sum_out.get(), back_st);
engine.Init(reduce_sum_out.get());
engine.Execute();

framework::LoDTensor rlt;
@@ -356,9 +355,8 @@ TEST(test_tracer, test_var_without_grad_var) {
ASSERT_EQ(y_in->GradVarBase()->GradOpNum(), 0UL);
ASSERT_EQ(vout->GradVarBase()->GradOpNum(), 1UL);

detail::BackwardStrategy back_st;
imperative::BasicEngine engine;
engine.Init(vout.get(), back_st);
engine.Init(vout.get());
engine.Execute();

// check the grad
13 changes: 13 additions & 0 deletions paddle/fluid/platform/flags.cc
@@ -508,3 +508,16 @@ DEFINE_int32(
"summary will be shown."
"If FLAGS_call_stack_level == 2, the python stack, c++ stack, and "
"error message summary will be shown.");

/**
* Debug related FLAG
* Name: sort_sum_gradient
* Since Version: 2.0.0
* Value Range: bool, default=false
* Example:
* Note: If True, gradients are summed by the reverse order of
* the forward execution sequence.
*/
DEFINE_bool(sort_sum_gradient, false,
"Sum gradients by the reverse order of "
"the forward execution sequence.");
3 changes: 2 additions & 1 deletion paddle/fluid/pybind/global_value_getter_setter.cc
@@ -38,6 +38,7 @@ DECLARE_bool(enable_rpc_profiler);
DECLARE_int32(multiple_of_cupti_buffer_size);
DECLARE_bool(reader_queue_speed_test_mode);
DECLARE_int32(call_stack_level);
DECLARE_bool(sort_sum_gradient);
// device management
DECLARE_int32(paddle_num_threads);
// executor
@@ -340,7 +341,7 @@ static void RegisterGlobalVarGetterSetter() {
REGISTER_PUBLIC_GLOBAL_VAR(
FLAGS_eager_delete_tensor_gb, FLAGS_enable_parallel_graph,
FLAGS_allocator_strategy, FLAGS_use_system_allocator, FLAGS_check_nan_inf,
FLAGS_call_stack_level, FLAGS_cpu_deterministic,
FLAGS_call_stack_level, FLAGS_sort_sum_gradient, FLAGS_cpu_deterministic,
FLAGS_enable_rpc_profiler, FLAGS_multiple_of_cupti_buffer_size,
FLAGS_reader_queue_speed_test_mode, FLAGS_pe_profile_fname,
FLAGS_print_sub_graph_dir, FLAGS_fraction_of_cpu_memory_to_use,
64 changes: 7 additions & 57 deletions paddle/fluid/pybind/imperative.cc
@@ -30,7 +30,6 @@ limitations under the License. */

#include "paddle/fluid/imperative/all_reduce.h"
#include "paddle/fluid/imperative/amp_auto_cast.h"
#include "paddle/fluid/imperative/backward_strategy.h"
#include "paddle/fluid/imperative/basic_engine.h"
#include "paddle/fluid/imperative/data_loader.h"
#include "paddle/fluid/imperative/layer.h"
@@ -507,50 +506,6 @@ void BindImperative(py::module *m_ptr) {
[]() { memory::allocation::MemoryMapFdSet::Instance().Clear(); });
#endif

py::class_<imperative::detail::BackwardStrategy> backward_strategy(
m, "BackwardStrategy", R"DOC(

BackwardStrategy is a descriptor of how to run the backward process.

**Note**:
**This API is only available in** `Dygraph <../../user_guides/howto/dygraph/DyGraph.html>`_ **Mode**

Attribute:
**sort_sum_gradient**:

If framework will sum the gradient by the reverse order of trace. eg. x_var ( :ref:`api_guide_Variable` ) will be the input of multiple OP such as :ref:`api_fluid_layers_scale` , this attr will decide if framework will sum gradient of `x_var` by the reverse order.

By Default: False

Examples:
.. code-block:: python

import numpy as np
import paddle.fluid as fluid

x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard():
x_var = fluid.dygraph.to_variable(x)
sums_inputs = []
# x_var will be multi-scales' input here
for _ in range(10):
sums_inputs.append(fluid.layers.scale(x_var))
ret2 = fluid.layers.sums(sums_inputs)
loss2 = fluid.layers.reduce_sum(ret2)
backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = True
loss2.backward(backward_strategy)
)DOC");
backward_strategy.def(py::init())
.def_property("sort_sum_gradient",
[](const imperative::detail::BackwardStrategy &self) {
return self.sorted_sum_gradient_;
},
[](imperative::detail::BackwardStrategy &self,
bool sorted_sum_gradient) {
self.sorted_sum_gradient_ = sorted_sum_gradient;
});

m.def("start_imperative_gperf_profiler",
[]() { imperative::StartProfile(); });

@@ -745,21 +700,18 @@ void BindImperative(py::module *m_ptr) {
inputs2.append(tmp)
ret2 = fluid.layers.sums(inputs2)
loss2 = fluid.layers.reduce_sum(ret2)
backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = True
loss2.backward(backward_strategy)
loss2.backward()
print(loss2.gradient())
loss2.clear_gradient()
print("After clear {}".format(loss2.gradient()))
)DOC")
.def("_run_backward",
[](imperative::VarBase &self,
const imperative::detail::BackwardStrategy &bckst,
const imperative::Tracer &tracer, bool retain_graph) {
[](imperative::VarBase &self, const imperative::Tracer &tracer,
bool retain_graph) {
// TODO(jiabin): when we impl more backward execution we can
// select them
auto *engine = tracer.GetEngine();
engine->Init(&self, bckst, retain_graph);
engine->Init(&self, retain_graph);
VLOG(3) << "Start backward";
engine->Execute();
VLOG(3) << "Finish backward";
@@ -1024,13 +976,11 @@ void BindImperative(py::module *m_ptr) {
&output_targets,
const std::vector<std::shared_ptr<imperative::VarBase>> &output_grads,
const std::vector<std::shared_ptr<imperative::VarBase>> &no_grad_vars,
const platform::Place &place,
const imperative::detail::BackwardStrategy &strategy,
bool create_graph, bool retain_graph, bool allow_unused,
bool only_inputs) {
const platform::Place &place, bool create_graph, bool retain_graph,
bool allow_unused, bool only_inputs) {
imperative::PartialGradEngine engine(
input_targets, output_targets, output_grads, no_grad_vars, place,
strategy, create_graph, retain_graph, allow_unused, only_inputs);
create_graph, retain_graph, allow_unused, only_inputs);
engine.Execute();
return engine.GetResult();
},
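For reference, the deleted BackwardStrategy demo above maps directly onto the flag-based form used in the updated docstring demo. A hedged migration sketch, where the fluid.set_flags exposure is an assumption based on the registration earlier in this PR:

.. code-block:: python

    import numpy as np
    import paddle.fluid as fluid

    # Old: backward_strategy = fluid.dygraph.BackwardStrategy()
    #      backward_strategy.sort_sum_gradient = True
    #      loss2.backward(backward_strategy)
    # New: request the same behavior through the global flag.
    fluid.set_flags({'FLAGS_sort_sum_gradient': True})  # assumed exposure

    x = np.ones([2, 2], np.float32)
    with fluid.dygraph.guard():
        x_var = fluid.dygraph.to_variable(x)
        # x_var feeds several scale ops, so its gradient is a sum whose
        # accumulation order is governed by the flag.
        sums_inputs = [fluid.layers.scale(x_var) for _ in range(10)]
        loss2 = fluid.layers.reduce_sum(fluid.layers.sums(sums_inputs))
        loss2.backward()
        print(loss2.gradient())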
1 change: 0 additions & 1 deletion python/paddle/__init__.py
@@ -225,7 +225,6 @@
from .framework import CUDAPlace #DEFINE_ALIAS
from .framework import CUDAPinnedPlace #DEFINE_ALIAS

from .framework import BackwardStrategy #DEFINE_ALIAS
from .framework import to_variable #DEFINE_ALIAS
from .framework import grad #DEFINE_ALIAS
from .framework import no_grad #DEFINE_ALIAS
1 change: 1 addition & 0 deletions python/paddle/fluid/__init__.py
@@ -196,6 +196,7 @@ def __bootstrap__():
'free_idle_chunk',
'free_when_no_cache_hit',
'call_stack_level',
'sort_sum_gradient',
]
if 'Darwin' not in sysstr:
read_env_flags.append('use_pinned_memory')
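Since 'sort_sum_gradient' is now listed in read_env_flags, the flag can also be supplied through the environment. A minimal sketch, assuming the usual FLAGS_ environment naming; the variable has to be set before paddle.fluid is imported, because __bootstrap__ consumes read_env_flags at import time.

.. code-block:: python

    import os

    # Must happen before the paddle.fluid import below; __bootstrap__ reads
    # the environment only once, during import.
    os.environ['FLAGS_sort_sum_gradient'] = '1'

    import paddle.fluid as fluid  # the flag takes effect during this import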
4 changes: 0 additions & 4 deletions python/paddle/fluid/dygraph/__init__.py
@@ -38,9 +38,6 @@
from . import learning_rate_scheduler
from .learning_rate_scheduler import *

from . import backward_strategy
from .backward_strategy import *

from . import jit
from .jit import *

@@ -69,7 +66,6 @@
__all__ += parallel.__all__
__all__ += checkpoint.__all__
__all__ += learning_rate_scheduler.__all__
__all__ += backward_strategy.__all__
__all__ += jit.__all__
__all__ += io.__all__
__all__ += rnn.__all__
19 changes: 0 additions & 19 deletions python/paddle/fluid/dygraph/backward_strategy.py

This file was deleted.
