
Add API to donate input buffer for dynamo execution #6587

Merged
merged 13 commits on Feb 27, 2024
only enable buffer donor aliasing in dynamo
JackCaoG committed Feb 24, 2024
commit 3561bdf62994d2891fb2d158671b5525f9ea2cc2
10 changes: 9 additions & 1 deletion test/dynamo/test_dynamo_aliasing.py
Expand Up @@ -4,6 +4,8 @@
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
from torch_xla.core.dynamo_bridge import AliasWithBufferDonorContext


class TestBufferDonationUtil(unittest.TestCase):

Expand All @@ -13,8 +15,14 @@ def test_hash_with_buffer_donor(self):
res = torch.cos(input)
hash_no_donor = torch_xla._XLAC._get_graph_hash([res])
self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
# Without the AliasWithBufferDonorContext, the buffer donor will be ignored,
# so we still expect the hash to be the same.
hash_with_donor = torch_xla._XLAC._get_graph_hash([res])
self.assertNotEqual(hash_no_donor, hash_with_donor)
self.assertEqual(hash_no_donor, hash_with_donor)

with AliasWithBufferDonorContext(True) as context:
hash_with_donor_and_context = torch_xla._XLAC._get_graph_hash([res])
self.assertNotEqual(hash_no_donor, hash_with_donor_and_context)


class TestBufferDonationAliasing(unittest.TestCase):
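The test above exercises the new donation hook only at the graph-hash level. For orientation, here is a rough end-to-end sketch of how the API added in this PR is meant to be used from dynamo; only `torch_xla._XLAC._set_buffer_donation` comes from this change, while the `openxla` backend name and the `optimizer_step` update pattern are assumptions for illustration.

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm


def optimizer_step(param, grad):
  # The output may reuse the donated storage of `param`.
  return param - 0.1 * grad


device = xm.xla_device()
param = torch.randn(1024, 1024, device=device)
grad = torch.randn(1024, 1024, device=device)

# Mark `param` as a buffer donor so the compiled graph may alias its buffer
# with an output. The flag is only honored inside the dynamo bridge, which
# wraps hashing and compilation in AliasWithBufferDonorContext.
torch_xla._XLAC._set_buffer_donation(param, True)

compiled = torch.compile(optimizer_step, backend='openxla')
param = compiled(param, grad)
# The old `param` buffer may have been donated to the output and must not be
# read again.
```

Outside of dynamo the donor flag is ignored for hashing, which is exactly what the updated `assertEqual` above checks.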
31 changes: 26 additions & 5 deletions torch_xla/core/dynamo_bridge.py
Expand Up @@ -28,6 +28,25 @@
ptxla_debug = int(os.environ.get('PT_XLA_DEBUG', '0')) == 1


class AliasWithBufferDonorContext(object):

def __init__(self, should_alias: bool):
self.should_alias = should_alias

def __enter__(self):
self.env_inited = 'XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR' in os.environ
if self.env_inited:
self.env_saved = os.environ['XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR']
os.environ[
'XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR'] = '1' if self.should_alias else '0'

def __exit__(self, exc_type, exc_val, exc_tb):
if self.env_inited:
os.environ['XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR'] = self.env_saved
else:
del os.environ['XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR']


@dataclasses.dataclass
class GraphInputMatcher:
"""
Expand Down Expand Up @@ -307,9 +326,10 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
# calculate graph hash
dumb_return_handler = DumbReturnHandler(xla_args, args_and_out,
xla_args_need_update_bool)
graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out)
if dynamo_debug:
print("graph_hash", graph_hash)
with AliasWithBufferDonorContext(True) as context:
graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out)
if dynamo_debug:
print("graph_hash", graph_hash)

# Collect all device data nodes that are needed to compute the args_and_out
# and wrap those device data nodes inside at::tensor(graph_input_xla_values).
Expand All @@ -328,8 +348,9 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
graph_input_tensor_ids,
graph_input_xla_values,
xla_args_tensor_ids)
# compiles and cache graph rooted at tensors in 'args_and_out'
torch_xla._XLAC._xla_warm_up_cache(args_and_out, [])
with AliasWithBufferDonorContext(True) as context:
# compiles and cache graph rooted at tensors in 'args_and_out'
torch_xla._XLAC._xla_warm_up_cache(args_and_out, [])

# Restore the original `xla_args`. Dynamo passed the real tensors as
# `xla_args` and we performed the tracing on them. During the tracing,
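For clarity, a small sketch of the env-var semantics that `AliasWithBufferDonorContext` implements above: the flag is set on entry and the previous state (set or unset) is restored on exit. The `show` helper is made up for illustration.

```python
import os

from torch_xla.core.dynamo_bridge import AliasWithBufferDonorContext


def show():
  # Hypothetical helper just to print the current flag value.
  print(os.environ.get('XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR', '<unset>'))


show()  # '<unset>', assuming the user has not exported the variable
with AliasWithBufferDonorContext(True):
  show()  # '1' while the context is active
show()  # '<unset>' again; a pre-existing value would have been restored instead
```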
91 changes: 51 additions & 40 deletions torch_xla/csrc/xla_graph_executor.cpp
Expand Up @@ -409,6 +409,12 @@ void XLAGraphExecutor::MarkStep(const torch::lazy::BackendDevice& device) {
ResetTrimCounter();
}

bool ShouldAliasBasedOnBufferDonor() {
// This env var can change at runtime, so do not cache it in a static bool here.
return runtime::sys_util::GetEnvBool("XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR",
                                     false);
}

Review comment from the PR author on this line: "this is not ideal.. I want to add a new config to coll.config but that struct now lives on upstream..."

std::vector<size_t> GetBufferDonorIndex(
const std::vector<torch::lazy::BackendDataPtr>& parameters_data) {
std::vector<size_t> buffer_donor_indexs;
Expand Down Expand Up @@ -502,10 +508,11 @@ torch::lazy::hash_t XLAGraphExecutor::GetGraphHash(
PostOrderData po_data = RunPostOrder(ir_values, &coll);
torch::lazy::hash_t res_hash = torch::lazy::HashCombine(
coll.hash, torch::lazy::Hash(po_data.parameter_sequence));
// TODO: only compute this if buffer donor is enabled.
res_hash = torch::lazy::HashCombine(
res_hash,
torch::lazy::Hash(GetBufferDonorIndex(po_data.parameters_data)));
if (ShouldAliasBasedOnBufferDonor()) {
res_hash = torch::lazy::HashCombine(
res_hash,
torch::lazy::Hash(GetBufferDonorIndex(po_data.parameters_data)));
}
DeviceContextArena::Get()->SaveOutputShapes(res_hash,
std::move(output_shapes));
DeviceContextArena::Get()->SaveGraphAsString(res_hash, tensors,
Expand Down Expand Up @@ -1290,38 +1297,41 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
// TODO(yeounoh) aliasing is disabled for partitioned computation,
// since the current aliasing compares the unpartitioned input and output
// shapes which can lead to incorrect aliasing pairs if sharded.
if (enable_aliasing && coll.config.sync_ltc_data &&
coll.config.force_ltc_data) {
// We can only alias at the step barrier, when force_ltc_data is true.
// Consider the case:
// 1. Tensor A(DEVICE_DATA)
// 2. Tensor B = A + 0.9
// 3. A += 0.4
// If we activate aliasing for A's graph, and we do:
// print(A)
// print(A)
// The first print will update DEVICE_DATA' with DEVICE_DATA+0.4, and the
// second print will again update DEVICE_DATA" with DEVICE_DATA'+0.4, which
// will lead to incorrect results.
// We cannot normally turn A's state into DEVICE_DATA, as if any of the
// sources is a view, this will not lead to correct results (as A's value
// taken at different times need to reflect view source changes):
// 1. Tensor A = some_graph_with_view_source(V)
// 2. print(A)
// 3. V += 1
// 4. print(A)
// The second print should reflect the new value due to V's changes.
// Also in the first example, unless we are doing a step barrier and hence
// include all live tensors, if the B value is not part of the graph, it
// will later fetch the new value of A, which is incorrect.
// But, when we issue a step barrier (force_ltc_data == true) we have to
// turn everything into DEVICE_DATA, so we can activate aliasing.
std::cerr << "build input output aliasing\n";
input_output_alias_pair =
BuildInputOutputAliases(tensors, coll.indices, &lowering_ctx);
} else if (enable_aliasing) {
std::cerr << "call SetBufferDonors\n";
buffer_donor_indices = SetBufferDonors(&lowering_ctx);
if (enable_aliasing) {
if (coll.config.sync_ltc_data && coll.config.force_ltc_data) {
// We can only alias at the step barrier, when force_ltc_data is true.
// Consider the case:
// 1. Tensor A(DEVICE_DATA)
// 2. Tensor B = A + 0.9
// 3. A += 0.4
// If we activate aliasing for A's graph, and we do:
// print(A)
// print(A)
// The first print will update DEVICE_DATA' with DEVICE_DATA+0.4, and the
// second print will again update DEVICE_DATA" with DEVICE_DATA'+0.4,
// which will lead to incorrect results. We cannot normally turn A's state
// into DEVICE_DATA, as if any of the sources is a view, this will not
// lead to correct results (as A's value taken at different times need to
// reflect view source changes):
// 1. Tensor A = some_graph_with_view_source(V)
// 2. print(A)
// 3. V += 1
// 4. print(A)
// The second print should reflect the new value due to V's changes.
// Also in the first example, unless we are doing a step barrier and hence
// include all live tensors, if the B value is not part of the graph, it
// will later fetch the new value of A, which is incorrect.
// But, when we issue a step barrier (force_ltc_data == true) we have to
// turn everything into DEVICE_DATA, so we can activate aliasing.
std::cerr << "build input output aliasing\n";
input_output_alias_pair =
BuildInputOutputAliases(tensors, coll.indices, &lowering_ctx);
} else if (ShouldAliasBasedOnBufferDonor()) {
// only alias based on buffer donor if LTC can't auto infer the input
// output aliasing.
std::cerr << "call SetBufferDonors\n";
buffer_donor_indices = SetBufferDonors(&lowering_ctx);
}
}

xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla());
Expand Down Expand Up @@ -1417,10 +1427,11 @@ XLAGraphExecutor::SyncTensorsGraphInternal(
PostOrderData po_data = RunPostOrder(ir_values, &coll);
coll.hash = torch::lazy::HashCombine(
coll.hash, torch::lazy::Hash(po_data.parameter_sequence));
// TODO: only include this if env var is enabled.
coll.hash = torch::lazy::HashCombine(
coll.hash,
torch::lazy::Hash(GetBufferDonorIndex(po_data.parameters_data)));
if (ShouldAliasBasedOnBufferDonor()) {
coll.hash = torch::lazy::HashCombine(
coll.hash,
torch::lazy::Hash(GetBufferDonorIndex(po_data.parameters_data)));
}

DebugUtil::SaveGraphHash(coll.hash);
TF_VLOG(4) << "Parameter sequence graph hash "
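Finally, the long comment in `XLAGraphExecutor::Compile` above is easier to follow as a concrete lazy-tensor snippet. This is just the comment's own scenario written out as an illustration of the hazard, not code from this PR.

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

A = torch.zeros(4, device=device)  # Tensor A(DEVICE_DATA)
B = A + 0.9                        # B's graph reads A's original device data
A += 0.4                           # A's pending graph: DEVICE_DATA + 0.4

# If A's graph aliased its input buffer outside a step barrier, materializing A
# twice would apply the update twice (DEVICE_DATA + 0.8 the second time), and B
# could later fetch the mutated value of A as well. Only at a step barrier
# (force_ltc_data == true), where every live tensor becomes DEVICE_DATA, is the
# automatic input/output aliasing safe; otherwise Compile falls back to the
# explicit buffer-donor path (SetBufferDonors) when the env flag is enabled.
print(A)
print(A)
```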