[Dy2St] Cleanup no need buffer inputs in grad node #69043
Merged
PR Category
Execute Infrastructure
PR Types
Bug fixes
Description
In Dy2St (dygraph-to-static), the grad node holds the forward inputs. For a no-need-buffer Tensor, however, only the meta is required, not the holder. So for these no-need-buffer Tensors, the Tensor held by the backward grad node is no longer the original Tensor but a copied Tensor that owns no holder (the holder is moved away).
As shown in the figure (which depicts only the Tensors that are no-need-buffer in backward), bar length represents lifetime, and
x -> y
means x holds a reference to y, so y's lifetime must be at least as long as x's. Because these Tensors have not yet been released (their release is scheduled on the Python side, when the PyObject refcount drops to 0) at the time they are set into the backward GradNode, they are not actually freed until backward finishes.
Therefore, for these no-need-buffer Tensors, this PR makes a copy: the grad node holds a copied DenseTensor without a holder, ensuring that what backward keeps alive is a Tensor with meta but no holder.
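The idea can be illustrated with a minimal toy model in Python. The `DenseTensor` class and `strip_holder_for_grad_node` helper below are hypothetical stand-ins for Paddle's C++ internals, not the actual API; they only model "keep the meta, drop the data holder":

```python
class DenseTensor:
    """Toy stand-in for Paddle's DenseTensor: meta + data holder."""
    def __init__(self, meta, holder):
        self.meta = meta        # shape/dtype info: cheap, always needed
        self.holder = holder    # underlying allocation: expensive to keep alive

def strip_holder_for_grad_node(tensor):
    """Hypothetical helper modeling the fix: copy the tensor so the
    grad node sees the meta, while the holder is dropped (in the real
    fix the holder is moved away) and can be freed earlier."""
    copied = DenseTensor(tensor.meta, tensor.holder)
    copied.holder = None  # grad node keeps no reference to the data
    return copied

# The forward input owns real data; the grad node holds only meta.
fwd_input = DenseTensor(meta={"shape": (2, 3), "dtype": "float32"},
                        holder=bytearray(24))
held_by_grad_node = strip_holder_for_grad_node(fwd_input)

assert held_by_grad_node.meta == fwd_input.meta  # meta preserved
assert held_by_grad_node.holder is None          # allocation not pinned
assert fwd_input.holder is not None              # original Tensor untouched
```

With this, the forward input's allocation is no longer kept alive through backward by the grad node; its lifetime is governed only by the Python-side refcount.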
PT has the same issue, but it is not exposed there: PT inserts an extra cast on its inputs, so the no-need-buffer value is the cast result of the input rather than the input itself, and values that are not graph inputs do not suffer from this problem.
GPU memory before and after the fix:
Max allocated is now clearly lower than PT (PT has not fixed this issue).
Note, however, that since x is an input of the ad func and is governed by Python-side GC, the ad func cannot delete it on its own; doing so could break later uses of x. As a result, under Dy2St, inputs are only GC'd after the entire subgraph finishes executing instead of being released as soon as the corresponding OP finishes, so peak memory is somewhat higher than in dygraph mode. SOT makes this more pronounced, since it has multiple subgraphs and some intermediate variables become subgraph inputs. For now there is no good solution to this.
PCard-66972