Memory optimization on Dynamic RNN #7599

Merged 19 commits on Jan 23, 2018

8 changes: 4 additions & 4 deletions paddle/framework/block_desc.cc
@@ -75,7 +75,7 @@ std::vector<VarDesc *> BlockDesc::AllVars() const {

OpDesc *BlockDesc::AppendOp() {
need_update_ = true;
ops_.emplace_back(new OpDesc());
ops_.emplace_back(new OpDesc(this));
return ops_.back().get();
}

@@ -86,7 +86,7 @@ void BlockDesc::AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc) {

OpDesc *BlockDesc::PrependOp() {
need_update_ = true;
ops_.emplace_front(new OpDesc());
ops_.emplace_front(new OpDesc(this));
return ops_.front().get();
}

@@ -153,7 +153,7 @@ BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc)
vars_[var_desc.name()].reset(new VarDesc(var_desc));
}
for (const proto::OpDesc &op_desc : desc_->ops()) {
ops_.emplace_back(new OpDesc(op_desc, prog));
ops_.emplace_back(new OpDesc(op_desc, prog, this));
}
}

@@ -162,7 +162,7 @@ BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
: prog_(prog), desc_(desc) {
need_update_ = true;
for (auto &op : other.ops_) {
ops_.emplace_back(new OpDesc(*op));
ops_.emplace_back(new OpDesc(*op, this));
}

for (auto &it : other.vars_) {
3 changes: 2 additions & 1 deletion paddle/framework/op_desc.cc
@@ -97,7 +97,7 @@ void OpDesc::CopyFrom(const OpDesc &op_desc) {
need_update_ = true;
}

OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog)
OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block)
: desc_(desc), need_update_(false) {
// restore inputs_
int input_size = desc_.inputs_size();
@@ -131,6 +131,7 @@ OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog)
attrs_[attr_name] = prog->MutableBlock(bid);
}
}
this->block_ = block;
}

proto::OpDesc *OpDesc::Proto() {
15 changes: 13 additions & 2 deletions paddle/framework/op_desc.h
@@ -25,15 +25,21 @@ namespace framework {

class BlockDesc;
class ProgramDesc;

class OpDesc {
public:
OpDesc() {}

OpDesc(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs);

OpDesc(const proto::OpDesc &desc, ProgramDesc *prog);
OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block);

explicit OpDesc(BlockDesc *block) : block_(block) {}

OpDesc(const OpDesc &other, BlockDesc *block) {
*this = other;
block_ = block;
}

void CopyFrom(const OpDesc &op_desc);

@@ -117,6 +123,10 @@ class OpDesc {

void Flush();

BlockDesc *Block() { return this->block_; }

void SetBlock(BlockDesc *block) { this->block_ = block; }

private:
template <typename MapType>
static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
@@ -129,6 +139,7 @@ }
}

proto::OpDesc desc_;
BlockDesc *block_; // not_own
// input arg name => input variable names
VariableNameMap inputs_;
// output arg name => output variable names
2 changes: 2 additions & 0 deletions paddle/framework/var_desc.h
@@ -66,6 +66,8 @@ class VarDesc {

std::string Name() const { return desc_.name(); }

void SetName(std::string name) { desc_.set_name(name); }

void SetShape(const std::vector<int64_t> &dims);

void SetDataType(proto::DataType data_type);
4 changes: 3 additions & 1 deletion paddle/pybind/protobuf.cc
@@ -212,6 +212,7 @@ void BindVarDsec(py::module &m) {
return name;
},
py::return_value_policy::reference)
.def("set_name", &VarDesc::SetName)
.def("set_shape", &VarDesc::SetShape)
.def("set_dtype", &VarDesc::SetDataType)
.def("shape", &VarDesc::Shape, py::return_value_policy::reference)
@@ -280,7 +281,8 @@ void BindOpDesc(py::module &m) {
.def("check_attrs", &OpDesc::CheckAttrs)
.def("infer_shape", &OpDesc::InferShape)
.def("infer_var_type", &OpDesc::InferVarType)
.def("serialize_to_string", SerializeMessage<OpDesc>);
.def("serialize_to_string", SerializeMessage<OpDesc>)
.def("block", &OpDesc::Block, py::return_value_policy::reference);
}

} // namespace pybind
180 changes: 130 additions & 50 deletions python/paddle/v2/fluid/memory_optimization_transpiler.py
@@ -31,10 +31,12 @@


class ControlFlowGraph(object):
def __init__(self, Program):
def __init__(self, Program, ops, forward_num):
self._program = Program
self._succesors = defaultdict(set)
self._presucessors = defaultdict(set)
self._ops = ops
self._forward_num = forward_num
self._successors = defaultdict(set)
self._presuccessors = defaultdict(set)
self._uses = defaultdict(set)
self._defs = defaultdict(set)
self._live_in = defaultdict(set)
@@ -45,25 +47,16 @@ def _add_connections(self, connections):
self._add(node1, node2)

def _add(self, node1, node2):
self._succesors[node1].add(node2)
self._presucessors[node2].add(node1)
self._successors[node1].add(node2)
self._presuccessors[node2].add(node1)

def _build_graph(self):
program_desc = self._program.get_desc()
block_size = program_desc.num_blocks()

# TODO(qijun) handle Program with if/while operators
self.global_block_desc = program_desc.block(0)
self.op_size = self.global_block_desc.op_size()

self.op_size = len(self._ops)
op_node_connections = [(i, i + 1) for i in range(self.op_size - 1)]
self._add_connections(op_node_connections)

self.ops = [self.global_block_desc.op(i) for i in range(self.op_size)]

for i in range(self.op_size):
self._uses[i].update(self.ops[i].input_arg_names())
self._defs[i].update(self.ops[i].output_arg_names())
self._uses[i].update(self._ops[i].input_arg_names())
self._defs[i].update(self._ops[i].output_arg_names())

def _update_graph(self, old_name, new_name, begin_idx=0):
for i in range(begin_idx, self.op_size):
@@ -103,7 +96,7 @@ def _dataflow_analyze(self):
live_out[i] = set(self._live_out[i])
self._live_in[i] = self._uses[i] | (
self._live_out[i] - self._defs[i])
for s in self._succesors[i]:
for s in self._successors[i]:
self._live_out[i] |= self._live_in[s]

if self._reach_fixed_point(live_in, live_out):
@@ -113,60 +106,147 @@ def _get_diff(self, a, b):
u = a & b
return a - u, b - u

def _has_var(self, block_desc, var_name, is_forward):
if is_forward:
return block_desc.has_var(str(var_name))
else:
return block_desc.has_var_recursive(str(var_name))

def _find_var(self, block_desc, var_name, is_forward):
if is_forward:
return block_desc.find_var(str(var_name))
else:
return block_desc.find_var_recursive(str(var_name))

def memory_optimize(self):
def check_var_validity(block_desc, x, is_forward):
if str(x) == "@EMPTY@":
return False
if not self._has_var(block_desc, x, is_forward):
return False
if self._find_var(block_desc, x, is_forward).persistable():
return False
if self._find_var(
block_desc, x,
is_forward).type() != core.VarDesc.VarType.LOD_TENSOR:
return False
return True

self._build_graph()
self._dataflow_analyze()
self.pool = []
for i in range(self.op_size):
op = self._ops[i]
if op.type() == "while" or op.type() == "while_grad":
continue
block_desc = op.block()
is_forward = i < self._forward_num
if self.pool:
out_pair = [(x, self.global_block_desc.var(str(x)).shape())
for x in self._defs[i]]
defs_can_optimize = filter(
lambda x: check_var_validity(block_desc, x, is_forward),
self._defs[i])
out_pair = [
(x, self._find_var(block_desc, x, is_forward).shape())
for x in defs_can_optimize
]
for x, x_shape in out_pair:
if not self.global_block_desc.var(str(x)).persistable():
for index, cache_pair in enumerate(self.pool):
cache_var = cache_pair[0]
cache_shape = cache_pair[1]
if x_shape == cache_shape:
x_dtype = self.global_block_desc.var(str(
x)).dtype()
cache_dtype = self.global_block_desc.var(
str(cache_var)).dtype()
for index, cache_pair in enumerate(self.pool):
cache_var = cache_pair[0]
cache_shape = cache_pair[1]
if x_shape == cache_shape:

Inline review comment from JiayiFeng (Collaborator), Jan 23, 2018:

I think we can divide the optimization into three levels:

Level 1: only reuse variables with the same prod(shape). Perfect reuse: no memory is wasted and no reallocation is needed.

Level 2: reuse variables whose prod(shape) is greater than the required prod(shape). There is no reallocation, but some memory may be wasted. To minimize the waste, the reused variable's prod(shape) should be as close to the required one as possible.

Optimizations at levels 1 and 2 are harmless. Enabling them is definitely better than doing nothing, so they should always be applied.

Level 3 (optional): reuse variables whose prod(shape) is less than the required one. Each reuse at this level results in a reallocation, which may slow training down, so this level is optional. To maximize reuse efficiency, the reused variable's prod(shape) should be as close to the required one as possible.

The whole optimization logic may be a bit complex, so I think it is better to wrap the pool as a class and implement the variable-picking logic as one of its member functions.

However, it is not necessary to complete all of this in the current PR. We can merge it first and keep refining in the future. I'm also glad to take part in the work.

Reply from the PR author (Member):

Yes. Actually, we could reuse a var whose shape is smaller than the cached var's. But the first dim is batch_size, which is -1 at compile time, so we cannot get the real size at compile time.

Follow-up from the PR author (Member):

@JiayiFeng Thanks for the detailed optimization policy. Sure, we can merge this PR first and you can work on it later.
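
For illustration only, and not part of this PR: a minimal sketch of the pool class proposed in the comment above, with the level-1/level-2/level-3 picking logic. The ReusePool name and its interface are hypothetical.

# Editor's illustrative sketch of the proposed pool class; not part of this PR.
from functools import reduce


class ReusePool(object):
    """Cache of (var_name, shape, dtype) entries that are dead and reusable."""

    def __init__(self):
        self._entries = []

    @staticmethod
    def _size(shape):
        # prod(shape); note that a -1 dim (e.g. batch_size) makes the real
        # size unknown at compile time, which is why this PR only reuses
        # variables with exactly matching shapes.
        return reduce(lambda a, b: a * b, shape, 1)

    def put(self, name, shape, dtype):
        self._entries.append((name, shape, dtype))

    def pick(self, shape, dtype, allow_realloc=False):
        """Pick the best cached variable for the required shape, or None.

        Level 1: equal prod(shape)   -> no waste, no reallocation.
        Level 2: larger prod(shape)  -> no reallocation, minimal waste.
        Level 3: smaller prod(shape) -> needs reallocation (optional).
        """
        required = self._size(shape)
        best_idx, best_size = None, None
        for idx, (name, cached_shape, cached_dtype) in enumerate(self._entries):
            if cached_dtype != dtype:
                continue
            cached = self._size(cached_shape)
            if cached < required and not allow_realloc:
                continue  # level 3 disabled
            # Prefer the candidate whose size is closest to the required size.
            if best_size is None or abs(cached - required) < abs(best_size - required):
                best_idx, best_size = idx, cached
        if best_idx is None:
            return None
        return self._entries.pop(best_idx)[0]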

if self._has_var(block_desc, cache_var, is_forward):
x_dtype = self._find_var(block_desc, x,
is_forward).dtype()
cache_dtype = self._find_var(
block_desc, cache_var, is_forward).dtype()
# TODO(qijun): actually, we should compare dtype_to_size[x_dtype]
# and dtype_to_size[cache_dtype]
if x_dtype == cache_dtype:
print(
("Hit Cache !!!! cache pool index "
"is %d, var name is %s, "
"cached var name is %s, "
"var shape is %s ") %
(index, x, cache_var, str(cache_shape)))
print(("Hit Cache !!!! cache pool index "
"is %d, var name is %s, "
"cached var name is %s, "
"var shape is %s ") %
(index, x, cache_var,
str(cache_shape)))
self.pool.pop(index)
if x == cache_var:
break
_rename_arg_(
self.ops, x, cache_var, begin_idx=i)
self._program.current_block().var(str(
x)).desc = self.global_block_desc.var(
str(cache_var))
self._ops, x, cache_var, begin_idx=i)
self._program.block(block_desc.id).var(
str(x)).desc = self._find_var(
block_desc, cache_var, is_forward)
self._update_graph(
x, cache_var, begin_idx=i)
break

in_diff, out_diff = self._get_diff(self._live_in[i],
self._live_out[i])
can_optimize = filter(
lambda x: not self.global_block_desc.var(str(x)).persistable(),
lambda x: check_var_validity(block_desc, x, is_forward),
in_diff)
if can_optimize:
for var_name in can_optimize:
self.pool.append(
(var_name,
self.global_block_desc.var(str(var_name)).shape()))

def get_program(self):
return self._program
self.pool.append((var_name, self._find_var(
block_desc, var_name, is_forward).shape()))


def get_cfgs(input_program):
ops_list = []
pdesc = input_program.get_desc()
block_desc = pdesc.block(0)
op_size = block_desc.op_size()
# Get global block ops
ops_list.append(([block_desc.op(i) for i in range(op_size)], op_size))

while_sub_block_ids = []
while_grad_sub_block_ids = []
while_pair = []

for i in range(op_size):
op = block_desc.op(i)
if op.type() == "while":
while_sub_block_ids.append(op.attr("sub_block").id)
elif op.type() == "while_grad":
while_grad_sub_block_ids.append(op.attr("sub_block").id)

# Find while/while_grad block pair
for grad_id in while_grad_sub_block_ids:
parent_id = pdesc.block(grad_id).parent
if parent_id in while_sub_block_ids:
while_pair.append((parent_id, grad_id))
while_sub_block_ids.remove(parent_id)

# Get while/while_grad block ops
for parent_id, grad_id in while_pair:
while_block_ops = []
while_block = pdesc.block(parent_id)
while_block_op_size = while_block.op_size()
for i in range(while_block_op_size):
while_block_ops.append(while_block.op(i))

while_grad_block = pdesc.block(grad_id)
while_grad_block_op_size = while_grad_block.op_size()
for i in range(while_grad_block_op_size):
while_block_ops.append(while_grad_block.op(i))

ops_list.append((while_block_ops, while_block_op_size))

# Process rest while block ops
for parent_id in while_sub_block_ids:
while_block_ops = []
while_block = pdesc.block(parent_id)
while_block_op_size = while_block.op_size()
for i in range(while_block_op_size):
while_block_ops.append(while_block.op(i))

ops_list.append((while_block_ops, while_block_op_size))

cfgs = [ControlFlowGraph(input_program, i, j) for i, j in ops_list]
return cfgs


def memory_optimize(input_program):
graph = ControlFlowGraph(input_program)
graph.memory_optimize()
result_program = graph.get_program()
return result_program
cfgs = get_cfgs(input_program)
for cfg in cfgs:
cfg.memory_optimize()
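
For context, after this change memory_optimize rewrites the given program in place and no longer returns a new program. A minimal usage sketch, assuming the fluid Python API of that era (program construction omitted):

import paddle.v2.fluid as fluid
from paddle.v2.fluid.memory_optimization_transpiler import memory_optimize

# Build the network as usual, then rewrite the default main program in place
# so that dead LoDTensor variables are reused, including inside paired
# while/while_grad sub-blocks.
memory_optimize(fluid.default_main_program())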