
Memory optimization on Dynamic RNN #7599


Merged
merged 19 commits into PaddlePaddle:develop on Jan 23, 2018

Conversation

@QiJune (Member) commented Jan 17, 2018

Tested on the machine translation demo, which has two hidden layers inside the RNN block.

The benchmark result for the first batch is as follows (in bytes):

Model                 Before      After       Saving
Machine translation   525144064   490860544   6.53%

@QiJune changed the title from "[WIP] Memory optimization on Dynamic RNN" to "Memory optimization on Dynamic RNN" on Jan 23, 2018
    defs_can_optimize)
defs_can_optimize = filter(
    lambda x: self._find_var(block_desc, x, is_forward).type() == core.VarDesc.VarType.LOD_TENSOR,
    defs_can_optimize)
Collaborator

Why not combine these three filters into one?

@QiJune (Member Author) Jan 23, 2018

Because I find that yapf cannot format this lambda style. It would become:

defs_can_optimize = filter(
    lambda x: str(x) != "@EMPTY@" and self._has_var(block_desc, x, is_forward) and not self._find_var(block_desc, x, is_forward).persistable() and self._find_var(block_desc, x, is_forward).type() == core.VarDesc.VarType.LOD_TENSOR,
    self._defs[i])

It's too looooong!

    can_optimize)
can_optimize = filter(
    lambda x: self._find_var(block_desc, x, is_forward).type() == core.VarDesc.VarType.LOD_TENSOR,
    can_optimize)
Collaborator

These three filters can also be combined, and they can share the same filter function with the filter in lines 132-134.

@QiJune (Member Author)

Done
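
For illustration, a hedged sketch (not the actual PR code) of how the three conditions might be collected into one named predicate shared by both call sites; the helper name _can_optimize_var is hypothetical, and it is assumed to live on the same class that provides _has_var and _find_var:

def _can_optimize_var(self, x, block_desc, is_forward):
    # Hypothetical helper: one named predicate instead of three chained
    # filter() calls, so yapf no longer has to wrap a long inline lambda.
    if str(x) == "@EMPTY@":
        return False
    if not self._has_var(block_desc, x, is_forward):
        return False
    var = self._find_var(block_desc, x, is_forward)
    return (not var.persistable()
            and var.type() == core.VarDesc.VarType.LOD_TENSOR)

# Both call sites can then reuse the same function, e.g.:
defs_can_optimize = filter(
    lambda x: self._can_optimize_var(x, block_desc, is_forward),
    self._defs[i])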

for index, cache_pair in enumerate(self.pool):
    cache_var = cache_pair[0]
    cache_shape = cache_pair[1]
    if x_shape == cache_shape:
@JiayiFeng (Collaborator) Jan 23, 2018

I think we can divide the optimization into three levels:

level 1: Only reusing variables with the same prod(shape). Perfect reuse: no memory waste and no reallocation.

level 2: Reusing variables whose prod(shape) is greater than the required prod(shape). There is no reallocation, but some memory may be wasted. To minimize the waste, the reused variable's prod(shape) should be as close to the required one as possible.

Optimizations at level 1 and level 2 are harmless. Enabling them is definitely better than doing nothing, so they should always be applied.

level 3 (optional): Reusing variables even if their prod(shape) is less than the required one. Obviously, each reuse at this level will result in a reallocation, which may slow training down, so this level is optional. To maximize reuse efficiency, the reused variable's prod(shape) should be as close to the required one as possible.

The whole optimization logic may be a bit complex, so I think it's better to wrap the pool in a class and implement the reused-variable picking logic as one of its member functions (see the sketch below).

However, it's not necessary to complete all of this in the current PR. We can merge it first and keep refining in the future. I'd also be glad to take part in this work.
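
For illustration only, a minimal sketch of such a pool class under the level-1/level-2 policy described above; the class and method names (VarReusePool, pick) are hypothetical, and shapes are assumed to be tuples of known positive dims:

from functools import reduce

def _numel(shape):
    # Product of all dims; assumes every dim is known and positive.
    return reduce(lambda a, b: a * b, shape, 1)

class VarReusePool(object):
    """Hypothetical cache pool: prefer an exact numel match (level 1),
    otherwise the smallest cached variable whose numel is >= the
    required one (level 2: no reallocation, least waste)."""

    def __init__(self):
        self.pool = []  # list of (var_name, shape) pairs

    def insert(self, var_name, shape):
        self.pool.append((var_name, shape))

    def pick(self, required_shape):
        required = _numel(required_shape)
        best_index = None
        best_numel = None
        for index, (cache_var, cache_shape) in enumerate(self.pool):
            numel = _numel(cache_shape)
            if numel < required:
                continue  # level 3 (reuse with reallocation) is not applied here
            if best_numel is None or numel < best_numel:
                best_index, best_numel = index, numel
        if best_index is None:
            return None
        return self.pool.pop(best_index)

Level 3 could later be added as an optional fallback inside pick() for the case where no cached variable is large enough.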

@QiJune (Member Author)

Yes. Actually, we could reuse a var whose shape is smaller than the var in the cache pool. But the first dim is batch_size, which is -1 at compile time, so we cannot get the real size at compile time.
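
To illustrate the constraint, a small hedged example (the function name can_reuse is hypothetical): because the batch dimension is -1 at compile time, prod(shape) is not a real size, so only exact shape equality is safe there:

from functools import reduce

def _numel(shape):
    return reduce(lambda a, b: a * b, shape, 1)

def can_reuse(x_shape, cache_shape):
    # With an unknown batch dim (-1) the element count is symbolic,
    # so only an identical compile-time shape guarantees a safe reuse.
    if -1 in x_shape or -1 in cache_shape:
        return x_shape == cache_shape
    # With fully static shapes, a larger cached buffer could also be reused.
    return _numel(cache_shape) >= _numel(x_shape)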

@QiJune (Member Author)

@JiayiFeng Thanks for the detailed optimization policy. Sure, we can merge this PR first and you can work on it later.

@QiJune merged commit d76fcb6 into PaddlePaddle:develop on Jan 23, 2018