
Add API to donate input buffer for dynamo execution #6587

Merged: 13 commits merged into master on Feb 27, 2024

Conversation

@JackCaoG (Collaborator) commented on Feb 22, 2024:

This PR adds an API, torch_xla._XLAC._set_buffer_donation, which can be used to mark the buffer associated with the current tensor to be donated in the next execution. This API is currently only enabled for the dynamo (torch.compile) use case; it is a no-op in the LazyTensor world, since LTC already has auto-aliasing based on in-place ops, and dynamo's functionalization pass removes all in-place ops.

Example usage:

  def dummy_inplace_add(self, input):
    input += 1
    return

  def test_manual_buffer_donation(self):
    device = xm.xla_device()
    input = torch.randn(5, 5).to(device)
    dummy_inplace_add_compiled = torch.compile(
        self.dummy_inplace_add, backend='openxla')

    met.clear_all()
    # input is a device_data, we should be able to set the buffer donation field.
    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
    # make sure buffer donation setting is correctly updated
    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
	
    for _ in range(100):
      # You don't need to keep calling this function if your function does not
      # cause dynamo recompilation; see the notes below for the reason.
      self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
      dummy_inplace_add_compiled(input)

Please note a couple of things:

  1. _set_buffer_donation is called on a tensor, but it applies to the buffer associated with that tensor. If the tensor you pass in is not backed by a buffer (for example, an intermediate tensor that has not been evaluated yet), this API is a no-op and returns false; see the sketch after this list.
  2. Buffer aliasing is set up at compilation time, and Torch Dynamo does not track this field, so changing _set_buffer_donation after the first execution (torch.compile compilation is triggered at the first execution of the compiled function) will not change the aliasing.
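A minimal sketch of note 1, illustrative only (it is not taken from this PR's tests; the tensor names are made up), assuming the usual torch_xla imports:

  import torch
  import torch_xla
  import torch_xla.core.xla_model as xm

  device = xm.xla_device()
  t = torch.randn(5, 5).to(device)  # backed by a device buffer (device_data)
  u = t + 1                         # lazy intermediate, not yet evaluated

  # The buffer behind `t` exists, so it can be marked for donation.
  assert torch_xla._XLAC._set_buffer_donation(t, True)
  assert torch_xla._XLAC._get_buffer_donation(t)

  # `u` has no materialized buffer yet, so per note 1 above this call is a
  # no-op and returns False.
  assert not torch_xla._XLAC._set_buffer_donation(u, True)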

@JackCaoG (Collaborator, Author) commented:
I chatted with @bdhirsh; ideally we want to figure out which input tensors can be aliased without the user calling this API. There should be a way for us to retrieve the aliasing information from the functionalization pass in aot-autograd.

The context is that if we want to alias an input buffer to an output buffer, we need to make sure the input buffer can no longer be accessed through the original tensor. For example,

a += 1

is fine because we know that once the in-place op happens, a will point to a different buffer, so it is safe to reuse the old buffer for something else.

If instead it is

b = a + 1

we can't alias, because a's buffer is still needed; once aliasing happens, the original buffer's value is invalidated. If we alias in the

b = a + 1

case, a's value will be incorrect after this computation. Dynamo comes with functionalization, which removes all in-place ops from the FX graph that is passed down, so the backend can no longer infer safe aliases from in-place ops; that is the problem in this case.
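To make the two cases concrete, here is a small illustrative sketch (not taken from the PR; the tensor names are made up):

  import torch
  import torch_xla.core.xla_model as xm

  device = xm.xla_device()
  a = torch.randn(5, 5).to(device)

  # Donation is safe here: after the in-place op, `a` points at the newly
  # produced buffer, so nothing reads the original input buffer anymore.
  a += 1

  # Donation is NOT safe here: `a` is still live after the computation, so its
  # buffer cannot be handed over to `b`; donating it would corrupt `a`'s value.
  b = a + 1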

@JackCaoG marked this pull request as ready for review on February 24, 2024 00:27
@JackCaoG changed the title from "[WIP]Add API to donate input buffer for dynamo execution" to "Add API to donate input buffer for dynamo execution" on Feb 24, 2024
@@ -407,6 +409,39 @@ void XLAGraphExecutor::MarkStep(const torch::lazy::BackendDevice& device) {
  ResetTrimCounter();
}

bool ShouldAliasBasedOnBufferDonor() {
  // This env var will be updated during run time, do not use static bool here.
  return runtime::sys_util::GetEnvBool("XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR",
@JackCaoG (Collaborator, Author) commented on this diff:
This is not ideal... I want to add a new config to coll.config, but that struct now lives upstream...
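For reference: as the comment notes, the flag is read each time rather than cached in a static bool. A hedged usage sketch (the default value is not visible in the excerpt above, so whether it must be set explicitly is an assumption):

  import os

  # Assumption: this env var gates whether the buffer-donor marks are used to
  # set up input/output aliasing; its default is not shown in the diff excerpt.
  os.environ['XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR'] = '1'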

@JackCaoG force-pushed the JackCaoG/dynamo_aliasing_2 branch from 09cc0cb to fda1c41 on February 24, 2024 00:45
@lsy323 (Collaborator) left a comment:
LGTM, left a few questions

@JackCaoG (Collaborator, Author) commented:
I will merge this change to unblock the user and address the review comments in a follow-up PR.

@JackCaoG merged commit 3e2a23c into master on Feb 27, 2024 (18 checks passed)
@alanwaketan (Collaborator) left a comment:
Sorry for being late, but looks good to me.
