Add API to donate input buffer for dynamo execution #6587
I chatted with @bdhirsh. Ideally we want to figure out which input tensors can be aliased without the user calling this API. There should be a way for us to retrieve the aliasing information from the functionalization pass in AOT Autograd. The context is that if we need to alias an input buffer to an output buffer, we need to make sure the input buffer can no longer be accessed through the original tensor. For example
is fine because we know that once the in-place op happens, a will point to a different buffer, so it is safe to reuse that buffer for something else. If it is
I can't alias, because a's buffer is still needed: once aliasing happens, the original buffer's value is invalidated. If we alias in the
case.
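The reasoning in this comment can be sketched with a toy model. This is not torch_xla code: `Buffer`, `execute`, and the `donate` flag are hypothetical names made up for illustration, and the two cases are an assumed reading of the in-place and out-of-place situations described above.

```python
class Buffer:
    """Hypothetical device buffer; reads fail once it has been donated."""

    def __init__(self, values):
        self.values = list(values)
        self.donated = False

    def read(self):
        if self.donated:
            raise RuntimeError("buffer was donated; its contents are invalid")
        return self.values


def execute(input_buf, fn, donate):
    """Apply fn elementwise; with donate=True the input storage is consumed."""
    output = Buffer(fn(v) for v in input_buf.read())
    if donate:
        input_buf.donated = True
    return output


# Case 1: in-place style update. The name `a` is rebound to the output
# buffer, so nothing can read the donated input afterwards.
a = Buffer([1, 2, 3])
a = execute(a, lambda v: v + 1, donate=True)
in_place_result = a.read()

# Case 2: out-of-place op. `a` still refers to the input buffer,
# so donating it breaks any later read through `a`.
a = Buffer([1, 2, 3])
b = execute(a, lambda v: v + 1, donate=True)
try:
    a.read()
    stale_read_failed = False
except RuntimeError:
    stale_read_failed = True
```

In the first case donation is safe because the only handle to the old buffer is dropped; in the second, the original tensor still needs its buffer, which is why aliasing cannot be applied blindly.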
@@ -407,6 +409,39 @@ void XLAGraphExecutor::MarkStep(const torch::lazy::BackendDevice& device) {
  ResetTrimCounter();
}

bool ShouldAliasBasedOnBufferDonor() {
  // This env var will be updated during run time, do not use static bool here.
  return runtime::sys_util::GetEnvBool("XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR",
This is not ideal. I want to add a new config to coll.config, but that struct now lives upstream.
LGTM, left a few questions
I will merge this change to unblock the user and address the comments in a follow-up PR.
Sorry for being late, but looks good to me.
This PR adds an API,
torch_xla._XLAC._set_buffer_donation
which can be used to mark the buffer associated with the current tensor to be donated in the next execution. This API is currently only enabled for the dynamo (torch.compile) use case; it is a no-op in the LazyTensor world, since LTC already does auto aliasing based on in-place ops, whereas dynamo's functionalization pass removes all in-place ops. Example usage should be
Please note a couple of things:
- _set_buffer_donation is called on a tensor, but it is applied to the buffer associated with that tensor. If the tensor you pass in is not backed by a buffer (for example, an intermediate tensor that has not been evaluated), this API will be a no-op and return false.
- If you call _set_buffer_donation after the first execution (torch.compile compilation is triggered at the first execution of the compiled function), aliasing will not change.
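The two caveats above can be modeled in a few lines of plain Python. This is only a sketch of the described behavior, not the torch_xla implementation: `ToyTensor`, `toy_set_buffer_donation`, and `ToyCompiledFunction` are hypothetical names, and the real API's signature is not reproduced here.

```python
class ToyTensor:
    """Hypothetical stand-in for a tensor; not a torch_xla type."""

    def __init__(self, materialized):
        # True when the tensor is already backed by a device buffer,
        # False for a pending (not yet evaluated) intermediate.
        self.materialized = materialized
        self.donate = False


def toy_set_buffer_donation(tensor, should_donate):
    """Caveat 1: the flag lands on the buffer, so a pending tensor is a no-op."""
    if not tensor.materialized:
        return False
    tensor.donate = should_donate
    return True


class ToyCompiledFunction:
    """Caveat 2: the aliasing decision is fixed at the first execution."""

    def __init__(self):
        self.aliases_input = None  # unknown until the first call "compiles"

    def __call__(self, tensor):
        if self.aliases_input is None:
            self.aliases_input = tensor.donate
        return self.aliases_input


pending = ToyTensor(materialized=False)
ready = ToyTensor(materialized=True)
pending_result = toy_set_buffer_donation(pending, True)
ready_result = toy_set_buffer_donation(ready, True)

late = ToyTensor(materialized=True)
compiled = ToyCompiledFunction()
first_run = compiled(late)            # compiled without donation
toy_set_buffer_donation(late, True)   # too late to affect aliasing
second_run = compiled(late)
```

In this model, marking the pending tensor returns False and changes nothing, and the flag set after the first call leaves the compiled function's aliasing decision unchanged.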