Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API to donate input buffer for dynamo execution #6587

Merged
merged 13 commits into from
Feb 27, 2024
Prev Previous commit
Next Next commit
add SetBufferDonors
  • Loading branch information
JackCaoG committed Feb 24, 2024
commit ebbafda24c4685aa65ed836c0caf83ad4634c95b
26 changes: 26 additions & 0 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <exception>
#include <fstream>
#include <functional>
#include <iostream>
#include <mutex>
#include <set>
#include <stdexcept>
Expand Down Expand Up @@ -60,6 +61,7 @@
#include "tsl/profiler/lib/traceme.h"
#include "xla/literal_util.h"
#include "xla/shape_util.h"
using std::cerr;

namespace torch_xla {
namespace {
Expand Down Expand Up @@ -1212,6 +1214,25 @@ XLAGraphExecutor::BuildInputOutputAliases(
return input_output_alias_pair;
}

// Registers buffer donors on the XLA builder for every input parameter whose
// backing data is flagged as donatable, and returns the parameter indices
// that were donated. Donated input buffers may be reused by the compiled
// computation for its outputs, avoiding an extra device allocation.
//
// Args:
//   lowering_ctx: the lowering context whose parameters are inspected and
//     whose builder receives the AddBufferDonor calls. Must be non-null.
// Returns: indices (into the parameter list) of the donated buffers.
std::vector<size_t> XLAGraphExecutor::SetBufferDonors(
    LoweringContext* lowering_ctx) {
  std::vector<size_t> buffer_donor_indices;
  const std::vector<torch::lazy::BackendDataPtr>& parameters_data =
      lowering_ctx->GetParametersData();
  for (size_t i = 0; i < parameters_data.size(); ++i) {
    auto data = std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
        parameters_data[i]);
    // dynamic_pointer_cast yields nullptr when the backend data is not a
    // ComputationClient::Data; guard before dereferencing.
    if (data && data->should_donate_buffer()) {
      buffer_donor_indices.push_back(i);
      // Empty param_index donates the whole parameter buffer (not a tuple
      // element).
      lowering_ctx->builder()->AddBufferDonor(/*param_number=*/i,
                                              /*param_index=*/{});
    }
  }
  TORCH_LAZY_VALUE_METRIC("InputOutputAliasCount", buffer_donor_indices.size());
  return buffer_donor_indices;
}

XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
const std::vector<XLATensorPtr>& tensors,
absl::Span<const std::string> devices, const SyncTensorCollection& coll,
Expand Down Expand Up @@ -1243,6 +1264,7 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
ShardingUtil::SetHloSharding(&lowering_ctx);

std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair;
std::vector<size_t> buffer_donor_indices;
// TODO(yeounoh) aliasing is disabled for partitioned computation,
// since the current aliasing compares the unpartitioned input and output
// shapes which can lead to an incorrect aliasing pairs if sharded.
Expand Down Expand Up @@ -1272,8 +1294,12 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
// will later fetch the new value of A, which is incorrect.
// But, when we issue a step barrier (force_ltc_data == true) we have to
// turn everything into DEVICE_DATA, so we can activate aliasing.
std::cerr << "build input output aliasing\n";
input_output_alias_pair =
BuildInputOutputAliases(tensors, coll.indices, &lowering_ctx);
} else if (enable_aliasing) {
std::cerr << "call SetBufferDonors\n";
buffer_donor_indices = SetBufferDonors(&lowering_ctx);
}

xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla());
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/xla_graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,8 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
const std::vector<XLATensorPtr>& tensors,
absl::Span<const size_t> indices, LoweringContext* lowering_ctx);

std::vector<size_t> SetBufferDonors(LoweringContext* lowering_ctx);

// We don't use upstream Compile to have BuildInputOutputAliases.
CompilationResult Compile(const std::vector<XLATensorPtr>& tensors,
absl::Span<const std::string> devices,
Expand Down