Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API to donate input buffer for dynamo execution #6587

Merged
merged 13 commits into from
Feb 27, 2024
2 changes: 1 addition & 1 deletion test/dynamo/test_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def fn_fallback(t):
xla_dynamo_res = dynamo_fn(t_xla)
self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res.cpu()))
self.assertEqual(met.metric_data('CompileTime')[0], 3)
self.assertEqual(met.metric_data('ExecuteTime')[0], 9)
self.assertEqual(met.metric_data('ExecuteTime')[0], 7)

# Second tracing
met.clear_all()
Expand Down
150 changes: 150 additions & 0 deletions test/dynamo/test_dynamo_aliasing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import unittest

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
from torch_xla.core.dynamo_bridge import AliasWithBufferDonorContext


class TestBufferDonationUtil(unittest.TestCase):
  """Checks how buffer-donation markers interact with graph hashing."""

  def test_hash_with_buffer_donor(self):
    device = xm.xla_device()
    input = torch.randn(5, 5).to(device)
    res = torch.cos(input)
    # Baseline hash computed before any donation marker is set.
    hash_no_donor = torch_xla._XLAC._get_graph_hash([res])
    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
    # without the AliasWithBufferDonorContext, buffer donor will be ignored,
    # so we still expect the hash to be the same.
    hash_with_donor = torch_xla._XLAC._get_graph_hash([res])
    self.assertEqual(hash_no_donor, hash_with_donor)

    # Inside the context the donation marker participates in hashing, so the
    # hash must change relative to the baseline.
    with AliasWithBufferDonorContext(True) as context:
      hash_with_donor_and_context = torch_xla._XLAC._get_graph_hash([res])
      self.assertNotEqual(hash_no_donor, hash_with_donor_and_context)


class TestDynamoBufferDonationAliasing(unittest.TestCase):
  """End-to-end tests of manual buffer donation under torch.compile."""

  def dummy_inplace_add(self, input):
    # In-place update: the donated input buffer can be reused for the output.
    input += 1
    return

  def dummy_add(self, input):
    # Out-of-place op: produces a fresh output tensor.
    return input + 1

  def test_manual_buffer_donation(self):
    device = xm.xla_device()
    input = torch.randn(5, 5).to(device)
    input_cloned = torch.clone(input)
    dummy_inplace_add_compiled = torch.compile(
        self.dummy_inplace_add, backend='openxla')

    met.clear_all()
    # input is a device_data, we should be able to set the buffer donation field.
    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
    # make sure buffer donation setting is correctly updated
    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
    self.assertIn('XlaSetBufferDonation', met.counter_names())
    self.assertEqual(met.counter_value('XlaSetBufferDonation'), 1)
    dummy_inplace_add_compiled(input)
    # Donation must not corrupt the result of the in-place add.
    torch.allclose(input_cloned.cpu() + 1, input.cpu())

  def test_manual_buffer_donation_for_non_inplce_op(self):
    device = xm.xla_device()
    input = torch.randn(5, 5).to(device)
    input_cloned = torch.clone(input)
    dummy_add_compiled = torch.compile(self.dummy_add, backend='openxla')

    met.clear_all()
    # input is a device_data, we should be able to set the buffer donation field.
    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
    # make sure buffer donation setting is correctly updated
    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
    self.assertIn('XlaSetBufferDonation', met.counter_names())
    self.assertEqual(met.counter_value('XlaSetBufferDonation'), 1)

    res = dummy_add_compiled(input)
    # check input's buffer has been aliased.
    xm.wait_device_ops()
    # Once donated, input's handle should report as deleted in the debug info.
    self.assertIn('Data Handle: Deleted',
                  torch_xla._XLAC._get_xla_tensor_debug_info(input))
    torch.allclose(input_cloned.cpu() + 1, res.cpu())

  def test_manual_buffer_donation_for_inplce_op_repeat(self):
    # use a different function than above dummy add otherwise XLA won't recompile
    def dummy_inplace(input):
      input += (0.3 * torch.cos(input))

    device = xm.xla_device()
    input = torch.randn(5, 5).to(device)
    input_cloned = torch.clone(input)
    dummy_inplace_add_compiled = torch.compile(dummy_inplace, backend='openxla')
    xm.mark_step()
    met.clear_all()
    # input is a device_data, we should be able to set the buffer donation field.
    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
    # make sure buffer donation setting is correctly updated
    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))

    for _ in range(100):
      dummy_inplace_add_compiled(input)
    # should_donate_buffer field is attached to the buffer and won't be inherited to
    # the output buffer (unless execution is a no-op). However dynamo doesn't track
    # this field so it will keep executing the graph with the input buffer aliased.
    self.assertFalse(torch_xla._XLAC._get_buffer_donation(input))
    # there shouldn't be any recompilation even though `should_donate_buffer` changed
    # after the first execution. This is because Dynamo does not trace this internal
    # field for xla.
    self.assertEqual(met.metric_data('CompileTime')[0], 1)

  def test_buffer_donation_on_non_data_tensor(self):
    device = xm.xla_device()
    input = torch.randn(5, 5).to(device)
    res = input + 1

    met.clear_all()
    # res now points to a `Add` IR, only data's buffer can be aliased
    self.assertFalse(torch_xla._XLAC._set_buffer_donation(res, True))
    self.assertFalse(torch_xla._XLAC._get_buffer_donation(res))
    self.assertNotIn('XlaSetBufferDonation', met.counter_names())


class TestNonDynamoBufferDonationAliasing(unittest.TestCase):
  """Verifies that buffer donation is ignored outside of dynamo execution."""

  def dummy_fn(self, input):
    return torch.cos(torch.sin(input))

  # Currently let's skip buffer donation api for the non-dynamo use case
  def test_buffer_donation_skip_for_non_dynamo(self):
    device = xm.xla_device()
    input = torch.randn(5, 5).to(device)
    xm.mark_step()
    met.clear_all()

    # We should be able to set buffer donation for input tensor, but when mark_step
    # triggered, the buffer donation should be ignored.
    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
    res = self.dummy_fn(input)
    xm.mark_step()
    # Make sure that input buffer is not aliased and can be used for other computations.
    # Also make sure that buffer_donation will not trigger recompilation in non-dynamo.
    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, False))
    res2 = self.dummy_fn(input)
    xm.mark_step()
    torch.allclose(res.cpu(), res2.cpu())
    self.assertEqual(met.metric_data('CompileTime')[0], 1)

  def test_no_op_mark_step_keep_buffer_donation(self):
    device = xm.xla_device()
    input = torch.randn(5, 5).to(device)
    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
    # A no-op mark_step must not clear the donation flag on the buffer.
    xm.mark_step()
    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
    xm.mark_step()
    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))


if __name__ == '__main__':
  # `sys` was never imported at module scope, and with the default
  # `exit=True` unittest.main() raises SystemExit before the result can be
  # inspected, making the sys.exit line dead code. Use exit=False so we can
  # translate the test result into the process exit code ourselves.
  import sys
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ function run_xla_op_tests1 {
run_test "$CDIR/test_metrics.py"
run_test "$CDIR/test_zero1.py"
run_test "$CDIR/dynamo/test_dynamo_integrations_util.py"
run_test "$CDIR/dynamo/test_dynamo_aliasing.py"
run_test "$CDIR/dynamo/test_dynamo.py"
run_test "$CDIR/dynamo/test_bridge.py"
run_test "$CDIR/dynamo/test_num_output.py"
Expand Down
31 changes: 26 additions & 5 deletions torch_xla/core/dynamo_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,25 @@
ptxla_debug = int(os.environ.get('PT_XLA_DEBUG', '0')) == 1


class AliasWithBufferDonorContext(object):
  """Context manager that toggles `XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR`.

  While active, graph hashing/compilation honors per-buffer donation markers.
  The previous value of the environment variable (or its absence) is restored
  on exit, even if the body raises.
  """

  def __init__(self, should_alias: bool):
    self.should_alias = should_alias

  def __enter__(self):
    # Remember the prior value (None means the variable was unset) so __exit__
    # can restore the environment exactly.
    self.env_saved = os.environ.get('XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR')
    os.environ[
        'XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR'] = '1' if self.should_alias else '0'
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    if self.env_saved is None:
      # pop() with a default avoids a KeyError if the body itself removed the
      # variable (the original `del` would crash in that case).
      os.environ.pop('XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR', None)
    else:
      os.environ['XLA_SHOULD_ALIAS_WITH_BUFFER_DONOR'] = self.env_saved


@dataclasses.dataclass
class GraphInputMatcher:
"""
Expand Down Expand Up @@ -307,9 +326,10 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
# calculate graph hash
dumb_return_handler = DumbReturnHandler(xla_args, args_and_out,
xla_args_need_update_bool)
graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out)
if dynamo_debug:
print("graph_hash", graph_hash)
with AliasWithBufferDonorContext(True) as context:
graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out)
if dynamo_debug:
print("graph_hash", graph_hash)

# Collect all device data nodes that is needed to compute the args_and_out
# and wrap those device data nodes inside a at::tensor(graph_input_xla_values).
Expand All @@ -328,8 +348,9 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
graph_input_tensor_ids,
graph_input_xla_values,
xla_args_tensor_ids)
# compiles and cache graph rooted at tensors in 'args_and_out'
torch_xla._XLAC._xla_warm_up_cache(args_and_out, [])
with AliasWithBufferDonorContext(True) as context:
# compiles and cache graph rooted at tensors in 'args_and_out'
torch_xla._XLAC._xla_warm_up_cache(args_and_out, [])

# Restore the original `xla_args`. Dynamo passed the real tensors as
# `xla_args` and we performed the tracing on them. During the tracing,
Expand Down
24 changes: 16 additions & 8 deletions torch_xla/csrc/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,8 @@ xla::XlaOp XlaHelpers::PromotedLogicalUnaryOp(
xla::StatusOr<xla::XlaComputation> XlaHelpers::WrapXlaComputation(
const xla::XlaComputation& computation,
const std::vector<xla::Shape>& parameter_shapes,
std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair) {
const std::vector<std::pair<int64_t, int64_t>>& input_output_alias_pair,
const std::vector<size_t>& buffer_donor_indices) {
xla::XlaBuilder builder(computation.proto().name());

// Construct a single tuple parameter.
Expand All @@ -928,13 +929,20 @@ xla::StatusOr<xla::XlaComputation> XlaHelpers::WrapXlaComputation(
xla::XlaOp orig_result = xla::Call(&builder, computation, inner_params);

// Rebuild aliasing.
for (const auto& [input_index, output_index] : input_output_alias_pair) {
// Both input and output will be a tuple so parameter_number will always
// be
// 0
builder.SetUpAlias(/*output_index=*/xla::ShapeIndex({output_index}),
/*param_number=*/0,
/*param_index=*/xla::ShapeIndex({input_index}));
if (input_output_alias_pair.size() > 0) {
for (const auto& [input_index, output_index] : input_output_alias_pair) {
// Both input and output will be a tuple so parameter_number will always
// be
// 0
builder.SetUpAlias(/*output_index=*/xla::ShapeIndex({output_index}),
/*param_number=*/0,
/*param_index=*/xla::ShapeIndex({input_index}));
}
} else if (buffer_donor_indices.size() > 0) {
for (size_t i : buffer_donor_indices) {
builder.AddBufferDonor(/*param_number=*/0,
/*param_index=*/xla::ShapeIndex({i}));
}
}

return builder.Build(orig_result);
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,8 @@ class XlaHelpers {
static xla::StatusOr<xla::XlaComputation> WrapXlaComputation(
const xla::XlaComputation& computation,
const std::vector<xla::Shape>& parameter_shapes,
std::vector<std::pair<int64_t, int64_t>> input_output_alias_pair);
const std::vector<std::pair<int64_t, int64_t>>& input_output_alias_pair,
const std::vector<size_t>& buffer_donor_indices);

static torch::lazy::Shape ConvertXlaShapeToLazy(const xla::Shape& shape);

Expand Down
54 changes: 54 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2244,6 +2244,60 @@ void InitXlaModuleBindings(py::module m) {
xtensor->MarkDynamicDimension(dim);
});

// This api will set the `should_donate_buffer_` field in the
// ComputationClient::Data. This api is currently only useful if you are
// running with `torch.compile`. Buffer assocaited with data with
// `should_donate_buffer_` set to true will be donated to the output, You
// should only use this api if
// 1. You are using torch.compile
// 2. You will inplace update a tensor in the `torch.compiled` function(so the
// currnet buffer can be donated after compuation)
m.def("_set_buffer_donation",
[](at::Tensor& input, bool should_donate) -> bool {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
bool buffer_donation_updated = false;
if (!xtensor) {
// input tensor is not a XLATensor, return here.
} else if (xtensor->CurrentDataHandle() != nullptr) {
auto data =
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
xtensor->CurrentDataHandle());
data->set_should_donate_buffer(should_donate);
buffer_donation_updated = true;
} else if (xtensor->CurrentIrValue().node != nullptr) {
torch::lazy::NodePtr node = xtensor->CurrentIrValue().node;
auto device_data = torch_xla::DeviceData::Cast(node.get());
if (device_data != nullptr) {
device_data->set_buffer_donation(should_donate);
buffer_donation_updated = true;
}
}
if (buffer_donation_updated) {
TORCH_LAZY_COUNTER("XlaSetBufferDonation", 1);
}
return buffer_donation_updated;
});

m.def("_get_buffer_donation", [](const at::Tensor& input) -> bool {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
if (!xtensor) {
return false;
} else if (xtensor->CurrentDataHandle() != nullptr) {
auto data = std::dynamic_pointer_cast<runtime::ComputationClient::Data>(
xtensor->CurrentDataHandle());
return data->should_donate_buffer();
} else if (xtensor->CurrentIrValue().node != nullptr) {
auto device_data =
torch_xla::DeviceData::Cast(xtensor->CurrentIrValue().node.get());
if (device_data != nullptr) {
return device_data->get_buffer_donation();
} else {
return false;
}
}
return false;
});

// -------------Dynamo Integration API Start-------------------------
/*
* Return tensor ids and at::tensors for all DeviceData nodes that is needed
Expand Down
10 changes: 10 additions & 0 deletions torch_xla/csrc/ops/device_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ class DeviceData : public XlaNode {
return data_;
}

void set_buffer_donation(bool should_donate_buffer) {
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data_)
->set_should_donate_buffer(should_donate_buffer);
}

bool get_buffer_donation() {
return std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data_)
->should_donate_buffer();
}

// With SPMD sharding propagation, we need to update the unpartitioned
// backend data with a partitioned one in the node operands. Note that
// this is permitted only if the node holds a placeholder.
Expand Down
13 changes: 11 additions & 2 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,26 @@ class ComputationClient {
class Data : public torch::lazy::BackendData {
public:
// TODO set Device and torch::lazy_shape correctly
Data(std::string device, xla::Shape shape)
Data(std::string device, xla::Shape shape,
bool should_donate_buffer = false)
: torch::lazy::BackendData(ParseDeviceString(device),
torch::lazy::Shape()),
xla_device_(device),
xla_shape_(std::move(shape)) {}
xla_shape_(std::move(shape)),
should_donate_buffer_(should_donate_buffer) {}

virtual ~Data() {}

const std::string& device() const { return xla_device_; }

const xla::Shape& shape() const { return xla_shape_; }

bool should_donate_buffer() const { return should_donate_buffer_; }

void set_should_donate_buffer(bool should_donate_buffer) {
should_donate_buffer_ = should_donate_buffer;
}

virtual std::string ToString() const = 0;

virtual bool HasSharding() const = 0;
Expand All @@ -72,6 +80,7 @@ class ComputationClient {
private:
std::string xla_device_;
xla::Shape xla_shape_;
bool should_donate_buffer_;
};

using DataPtr = std::shared_ptr<Data>;
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class PjRtComputationClient : public ComputationClient {
if (HasValue()) {
ss << reinterpret_cast<std::uintptr_t>(buffer.get()) << "\n";
} else {
ss << "None\n";
ss << (buffer == nullptr ? "None" : "Deleted") << "\n";
}
return ss.str();
}
Expand Down
Loading
Loading