Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API to donate input buffer for dynamo execution #6587

Merged
merged 13 commits into from
Feb 27, 2024
Prev Previous commit
Next Next commit
make sure compilation hash tracks buffer donor index
  • Loading branch information
JackCaoG committed Feb 24, 2024
commit 0b6c3a8107799e73a4ab7a3885f0f07127d17c1b
11 changes: 11 additions & 0 deletions test/dynamo/test_dynamo_aliasing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,17 @@
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

class TestBufferDonationUtil(unittest.TestCase):
  """Tests for the buffer-donation helpers exposed via torch_xla._XLAC."""

  def test_hash_with_buffer_donor(self):
    """Flagging an input for donation must yield a different graph hash."""
    xla_device = xm.xla_device()
    source = torch.randn(5, 5).to(xla_device)
    outputs = [torch.cos(source)]

    # Hash of the pending graph before any buffer is marked for donation.
    hash_no_donor = torch_xla._XLAC._get_graph_hash(outputs)

    # Mark the graph input as a donor; the call is expected to report success.
    donation_set = torch_xla._XLAC._set_buffer_donation(source, True)
    self.assertTrue(donation_set)

    # Same graph, but now with a donor input: the hash has to differ.
    hash_with_donor = torch_xla._XLAC._get_graph_hash(outputs)
    self.assertNotEqual(hash_no_donor, hash_with_donor)


class TestBufferDonationAliasing(unittest.TestCase):

Expand Down
57 changes: 36 additions & 21 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,33 @@ void XLAGraphExecutor::MarkStep(const torch::lazy::BackendDevice& device) {
ResetTrimCounter();
}

// Returns the positions in `parameters_data` whose backing data reports
// `should_donate_buffer()`. Indices are produced in ascending order.
//
// An entry that is null, or that is not a runtime::ComputationClient::Data,
// makes the dynamic cast yield a null pointer; such entries are treated as
// "not a donor" instead of being dereferenced.
std::vector<size_t> GetBufferDonorIndex(
    const std::vector<torch::lazy::BackendDataPtr>& parameters_data) {
  std::vector<size_t> buffer_donor_indexs;
  for (size_t i = 0; i < parameters_data.size(); ++i) {
    const auto data =
        std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
            parameters_data[i]);
    // Guard the cast result: dereferencing a failed cast is undefined
    // behavior.
    if (data != nullptr && data->should_donate_buffer()) {
      buffer_donor_indexs.push_back(i);
    }
  }
  return buffer_donor_indexs;
}

// Registers each donor parameter of `lowering_ctx` with the context's
// builder, reports the number of registered donors through the
// "InputOutputAliasCount" metric, and returns the donor parameter indices.
std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
    LoweringContext* lowering_ctx) {
  // Donor indices are computed by GetBufferDonorIndex from the parameters
  // data held by the lowering context.
  std::vector<size_t> donor_params =
      GetBufferDonorIndex(lowering_ctx->GetParametersData());

  const size_t donor_count = donor_params.size();
  for (size_t k = 0; k < donor_count; ++k) {
    const size_t param_number = donor_params[k];
    // An empty param_index is passed for every donor.
    lowering_ctx->builder()->AddBufferDonor(param_number,
                                            /*param_index=*/{});
  }

  TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount", donor_params.size());
  return donor_params;
}

void XLAGraphExecutor::WaitDeviceOps(absl::Span<const std::string> devices) {
std::set<torch::lazy::BackendDevice> wait_devices;
if (!devices.empty()) {
Expand Down Expand Up @@ -475,6 +502,10 @@ torch::lazy::hash_t XLAGraphExecutor::GetGraphHash(
PostOrderData po_data = RunPostOrder(ir_values, &coll);
torch::lazy::hash_t res_hash = torch::lazy::HashCombine(
coll.hash, torch::lazy::Hash(po_data.parameter_sequence));
// TODO: only compute this if buffer donor is enabled.
res_hash = torch::lazy::HashCombine(
res_hash,
torch::lazy::Hash(GetBufferDonorIndex(po_data.parameters_data)));
DeviceContextArena::Get()->SaveOutputShapes(res_hash,
std::move(output_shapes));
DeviceContextArena::Get()->SaveGraphAsString(res_hash, tensors,
Expand Down Expand Up @@ -1224,27 +1255,6 @@ XLAGraphExecutor::BuildInputOutputAliases(
return input_output_alias_pair;
}

// Walks the parameters data of `lowering_ctx`, registers every parameter
// whose data reports `should_donate_buffer()` as a buffer donor on the
// context's builder, records the donor count under the
// "InputOutputAliasCount" metric, and returns the donor parameter indices in
// ascending order.
//
// Compared with the earlier form of this function, the unconditional
// per-parameter debug prints to stderr are gone, and a failed dynamic cast
// (null result) is skipped instead of being dereferenced.
std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
    LoweringContext* lowering_ctx) {
  std::vector<size_t> buffer_donor_indexs;
  const std::vector<torch::lazy::BackendDataPtr>& parameters_data =
      lowering_ctx->GetParametersData();
  for (size_t i = 0; i < parameters_data.size(); ++i) {
    const auto data =
        std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
            parameters_data[i]);
    if (data != nullptr && data->should_donate_buffer()) {
      buffer_donor_indexs.push_back(i);
      // An empty param_index is passed for every donor.
      lowering_ctx->builder()->AddBufferDonor(/*param_number=*/i,
                                              /*param_index=*/{});
    }
  }
  TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount", buffer_donor_indexs.size());
  return buffer_donor_indexs;
}

XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
const std::vector<XLATensorPtr>& tensors,
absl::Span<const std::string> devices, const SyncTensorCollection& coll,
Expand Down Expand Up @@ -1407,6 +1417,11 @@ XLAGraphExecutor::SyncTensorsGraphInternal(
PostOrderData po_data = RunPostOrder(ir_values, &coll);
coll.hash = torch::lazy::HashCombine(
coll.hash, torch::lazy::Hash(po_data.parameter_sequence));
// TODO: only include this if env var is enabled.
coll.hash = torch::lazy::HashCombine(
coll.hash,
torch::lazy::Hash(GetBufferDonorIndex(po_data.parameters_data)));

DebugUtil::SaveGraphHash(coll.hash);
TF_VLOG(4) << "Parameter sequence graph hash "
<< torch::lazy::HashToString(coll.hash);
Expand Down