Memory optimization on Dynamic RNN #7599
Conversation
        defs_can_optimize)
defs_can_optimize = filter(
    lambda x: self._find_var(block_desc, x, is_forward).type() == core.VarDesc.VarType.LOD_TENSOR,
    defs_can_optimize)
Why not combine these three filters into one?
Because I find that yapf cannot format the lambda style nicely. It would become:

defs_can_optimize = filter(
    lambda x: str(x) != "@EMPTY@" and self._has_var(block_desc, x, is_forward) and not self._find_var(block_desc, x, is_forward).persistable() and self._find_var(block_desc, x, is_forward).type() == core.VarDesc.VarType.LOD_TENSOR,
    self._defs[i])

It's too looooong!
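One way to keep a single filter readable despite yapf is to move the predicate into a named method. This is only a sketch against the code quoted above; the helper name `_is_reusable_var` is invented here and is not part of this PR.

```python
# Hypothetical helper method on the same optimizer class; the name
# _is_reusable_var is invented for illustration.
def _is_reusable_var(self, block_desc, x, is_forward):
    """A variable can be reused only if it exists in the block,
    is not persistable, and is a plain LoDTensor."""
    if str(x) == "@EMPTY@":
        return False
    if not self._has_var(block_desc, x, is_forward):
        return False
    var = self._find_var(block_desc, x, is_forward)
    return (not var.persistable()
            and var.type() == core.VarDesc.VarType.LOD_TENSOR)


# The three chained filter calls then collapse into one short call:
defs_can_optimize = filter(
    lambda x: self._is_reusable_var(block_desc, x, is_forward),
    self._defs[i])
```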
        can_optimize)
can_optimize = filter(
    lambda x: self._find_var(block_desc, x, is_forward).type() == core.VarDesc.VarType.LOD_TENSOR,
    can_optimize)
These three filter calls can also be combined, and they can share the same filter function with the one in lines 132-134.
Done
for index, cache_pair in enumerate(self.pool):
    cache_var = cache_pair[0]
    cache_shape = cache_pair[1]
    if x_shape == cache_shape:
I think we can divide the optimization into three levels:

Level 1: Only reuse variables with the same prod(shape). This is perfect reuse: no memory is wasted and no reallocation happens.

Level 2: Reuse variables whose prod(shape) is greater than the required prod(shape). There is no reallocation, but some memory may be wasted. To minimize the waste, the reused variable's prod(shape) should be as close to the required one as possible.

Levels 1 and 2 are harmless. Enabling them is definitely better than doing nothing, so they should always be applied.

Level 3 (optional): Reuse variables even if their prod(shape) is less than the required one. Obviously, each reuse at this level results in a reallocation, which may slow training down, so this level is optional. To maximize reuse efficiency, the reused variable's prod(shape) should be as close to the required one as possible.
The whole optimization logic may be a bit complex, so I think it's better to wrap the pool in a class and implement the reuse-candidate picking logic as one of its member functions. However, it's not necessary to complete all of this in the current PR. We can merge it first and keep refining in the future. I'm also glad to take part in this work.
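As a rough sketch of what such a class might look like (the class name ReusePool and its methods are invented here, not taken from this PR), the level-1/level-2 picking logic could be:

```python
import numpy as np


class ReusePool(object):
    """Hypothetical wrapper around the list of (var_name, shape) pairs."""

    def __init__(self):
        self.pool = []  # each entry is a (cache_var, cache_shape) pair

    def insert(self, var_name, shape):
        self.pool.append((var_name, shape))

    def pick(self, x_shape):
        """Pick a cached variable to reuse for a tensor of shape x_shape.

        Level 1: an entry with exactly the required element count is
        returned immediately (perfect reuse, no waste, no reallocation).
        Level 2: otherwise, the smallest entry that is still large enough
        is returned (no reallocation, minimal waste).
        Level 3 (reusing smaller entries, which forces a reallocation) is
        left out of this sketch.
        """
        required = int(np.prod(x_shape))
        best_index, best_size = None, None
        for index, (cache_var, cache_shape) in enumerate(self.pool):
            size = int(np.prod(cache_shape))
            if size == required:  # level 1: perfect fit
                return self.pool.pop(index)
            if size > required and (best_size is None or size < best_size):
                best_index, best_size = index, size  # level 2 candidate
        if best_index is not None:  # level 2: closest larger entry
            return self.pool.pop(best_index)
        return None  # nothing reusable; caller allocates a new variable
```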
Yes. Actually, we could reuse a variable if its shape is smaller than the variable in the cache pool. But the first dimension is batch_size, which is -1 at compile time, so we cannot get the real size at compile time.
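A tiny sketch of that constraint (the helper below is hypothetical, not from the PR): once the batch dimension is -1, the element count cannot be computed at compile time, so only exact shape matching is safe.

```python
def compile_time_size(shape):
    """Return the element count if every dim is known, otherwise None.

    With batch_size stored as -1 at compile time, the real size is
    unknown, so size-based (level 2/3) reuse cannot be decided and the
    pass falls back to exact shape matching.
    """
    size = 1
    for dim in shape:
        if dim < 0:  # e.g. the batch dimension is -1
            return None
        size *= dim
    return size
```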
@JiayiFeng Thanks for the detailed optimization policy. Sure, we can merge this PR first and you can work on it later.
I have tested this on the machine translation demo, which has two hidden layers inside the RNN block. The benchmark result for the first batch is as follows (in bytes):