[scan] Reduce memory usage #8562

Merged 1 commit into master from scan/activation-alias-input on Jan 15, 2025
Conversation

tengyifei (Collaborator)

This change optimizes the memory usage of `scan` by turning some copies into aliases. Specifically, it optimizes intermediate activations that are aliases of inputs, which is a common occurrence.

In principle, we could have `forward` return all the intermediate activations, including those that are aliases of an input tensor. However, those inputs would then be duplicated as part of the output of a `scan` call, because we want to save all activations during the forward pass of a `scan`. The XLA compiler can't optimize away this duplication, likely because the tensors sit behind a DynamicSlice + DynamicUpdateSlice, so we end up doubling the memory usage from those inputs.

To reduce memory usage, we can have `forward` return only the activations that are not aliases of inputs, called `partial_activations`. The autograd implementation of `scan` then calls `alias_input` outside of the scan to add back the activations that are aliases of input tensors, turning the partial activations back into full activations.
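To make the mechanism concrete, here is a minimal, self-contained sketch in plain PyTorch (no XLA). The names `partial_activations` and `alias_input` follow the description above, but the signatures, the toy `layer` step, and the Python loop standing in for the scan are illustrative assumptions, not the actual torch_xla implementation.

```python
# Minimal sketch in plain PyTorch (no XLA). `partial_activations` and
# `alias_input` follow the PR description; the signatures and the toy
# `layer` step are illustrative assumptions, not the torch_xla API.
import torch


def layer(carry, x, w):
    """One scan step. Backward would need (x, h), but `x` is already a
    slice of the scan input, so stacking it again inside the scan would
    duplicate that memory."""
    h = x @ w + carry
    y = torch.tanh(h)
    # Return only the activation that is NOT an alias of an input.
    partial_activations = (h,)
    return y, partial_activations


def alias_input(partial_activations, xs):
    """Outside the scan, re-attach the input alias to recover the full
    activation tuple without materializing a copy of `xs`."""
    (stacked_h,) = partial_activations
    return (xs, stacked_h)


def scan(carry, xs, w):
    ys, hs = [], []
    for x in xs:  # Python-loop stand-in for the XLA scan.
        carry, (h,) = layer(carry, x, w)
        ys.append(carry)
        hs.append(h)
    partial_activations = (torch.stack(hs),)
    # The stacked input `xs` is reused by reference rather than being
    # copied into the scan output alongside the other saved activations.
    full_activations = alias_input(partial_activations, xs)
    return torch.stack(ys), full_activations


xs = torch.randn(4, 3)
w = torch.randn(3, 3)
carry = torch.zeros(3)
ys, (saved_xs, saved_h) = scan(carry, xs, w)
assert saved_xs is xs  # the input was never duplicated
print(ys.shape, saved_h.shape)
```

In the real change the stacking happens inside the XLA scan itself; the key point is only that activations aliasing the inputs are added back outside the scan, so they never appear a second time in the scan's output.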

tengyifei requested a review from qihqi on January 14, 2025 00:20
tengyifei marked this pull request as ready for review on January 14, 2025 00:31
Review thread on torch_xla/experimental/scan.py (outdated, resolved)
tengyifei force-pushed the scan/activation-alias-input branch from fc707b8 to 385797d on January 15, 2025 20:00
tengyifei merged commit dc76879 into master on Jan 15, 2025
12 checks passed
qihqi pushed a commit that referenced this pull request Jan 16, 2025
tengyifei deleted the scan/activation-alias-input branch on January 22, 2025 21:34
tengyifei added a commit that referenced this pull request Jan 22, 2025
This was referenced Jan 22, 2025
tengyifei added a commit that referenced this pull request Jan 22, 2025