
Commit 1bdbb22

[Torch] Add index_put operator
1 parent d280118 commit 1bdbb22

2 files changed: +55 -0 lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 25 additions & 0 deletions
@@ -2010,6 +2010,29 @@ def scatter(self, inputs, input_types):
         src = inputs[3]
         return _op.transform.scatter(data, index, src, axis)
 
+    def index_put(self, inputs, input_types):
+        in_tensor = inputs[0]
+        indices = inputs[1]
+        values = inputs[2]
+        accumulate = inputs[3]
+        # The accumulate parameter is ignored.
+        # torch.index_put defaults to accumulate=False, but relay.scatter_nd accumulates values.
+        # We assume there are no duplicate indices in the torch.index_put input.
+        if not accumulate:
+            logging.warning("torch.index_put accumulate parameter is False. "
+                            "TVM uses the tvm.relay.scatter_nd operator, which accumulates values. "
+                            "Make sure there are no duplicate indices in the torch.index_put input.")
+        # Relay scatter_nd does not support an input tensor.
+        # We assume torch.index_put is used with an all-zeros input tensor;
+        # scatter_nd will create an all-zeros tensor with the given shape.
+        out_shape = self.infer_shape(in_tensor)
+        logging.warning("tvm.relay.scatter_nd operator does not support the input tensor parameter. "
+                        "TVM assumes that torch.index_put is used with an all-zeros input tensor.")
+        # Combine the list of index tensors into one index tensor with shape (N, _).
+        indices_expdim = [self.unsqueeze((x, 0), None) for x in indices]
+        indices_concat = self.concatenate((indices_expdim, 0), None)
+        return _op.transform.scatter_nd(values, indices_concat, out_shape)
+
     def scalar_tensor(self, inputs, input_types):
         data = inputs[0]
         cast_map = {
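
For context (not part of the diff), the mapping this converter relies on can be checked directly in PyTorch: index_put on an all-zeros tensor with accumulate=False and no duplicate indices behaves like a pure scatter, with the per-dimension index tensors stacked along a new leading axis, which matches the (N, _) index layout that relay.scatter_nd expects. A minimal sketch, with ad hoc variable names:

import torch

data = torch.zeros(3, 5)
xidx = torch.tensor([0, 1, 2])
yidx = torch.tensor([0, 1, 3])
values = torch.tensor([2.0, 4.0, 7.0])

out = torch.index_put(data, indices=[xidx, yidx], values=values)

# Stacking the index tensors mirrors the unsqueeze + concatenate above:
# shape (2, 3), one row per input dimension.
stacked = torch.stack([xidx, yidx])
assert stacked.shape == (2, 3)
assert out[0, 0] == 2.0 and out[1, 1] == 4.0 and out[2, 3] == 7.0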
@@ -2326,6 +2349,8 @@ def create_convert_map(self):
             "aten::nonzero": self.nonzero,
             "aten::nonzero_numpy": self.nonzero_numpy,
             "aten::scatter": self.scatter,
+            "aten::index_put": self.index_put,
+            "aten::index_put_": self.index_put,
             "aten::scalar_tensor": self.scalar_tensor,
             "aten::__interpolate": self.interpolate,
             "aten::IntImplicit": self.identity,
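
Both the functional op and its in-place variant are routed to the same converter: tracing the Tensor method index_put_ records an aten::index_put_ node rather than aten::index_put. A small sketch of that assumption (illustration only, not part of the commit; the tracer may warn about the in-place mutation of an input):

import torch

class InPlacePut(torch.nn.Module):
    def forward(self, data, xidx, yidx, values):
        # The in-place method shows up as aten::index_put_ in the traced graph.
        return data.index_put_([xidx, yidx], values)

traced = torch.jit.trace(
    InPlacePut(),
    (torch.zeros(3, 5), torch.tensor([0]), torch.tensor([1]), torch.tensor([9.0])),
)
print(traced.graph)  # expect an aten::index_put_ node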

tests/python/frontend/pytorch/test_forward.py

Lines changed: 30 additions & 0 deletions
@@ -3327,6 +3327,36 @@ def test_fn_scatter_add(dim):
     verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], targets)
 
 
+def test_forward_index_put():
+    # torch.index_put for a 2D tensor and default accumulate (False)
+    def test_fn_index_put2():
+        return lambda data, xidx, yidx, values: \
+            torch.index_put(data, indices=[xidx, yidx], values=values)
+
+    # torch.index_put for a 3D tensor and accumulate=True
+    def test_fn_index_put3a():
+        return lambda data, xidx, yidx, zidx, values: \
+            torch.index_put(data, indices=[xidx, yidx, zidx], values=values, accumulate=True)
+
+    shape = (3, 5)
+    in_data = torch.zeros(shape)
+    xidx = torch.tensor([0, 1, 2, 2])
+    yidx = torch.tensor([0, 1, 3, 4])
+    values = torch.tensor([2.0, 4.0, 7.0, 9.0])
+
+    targets = ["llvm", "cuda"]
+    verify_trace_model(test_fn_index_put2(), [in_data, xidx, yidx, values], targets)
+
+    shape = (3, 5, 3)
+    in_data = torch.zeros(shape)
+    xidx = torch.tensor([0, 1, 2, 2, 0])
+    yidx = torch.tensor([0, 1, 3, 4, 0])
+    zidx = torch.tensor([0, 1, 1, 2, 0])
+    values = torch.tensor([2.0, 4.0, 7.0, 9.0, 1.0])
+
+    verify_trace_model(test_fn_index_put3a(), [in_data, xidx, yidx, zidx, values], targets)
+
+
 def test_numel():
     class Numel(Module):
         def forward(self, data):
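
For reference (an illustration, not part of the test file), the 2D case above writes values[i] at coordinate (xidx[i], yidx[i]) and leaves the rest of the zero-initialized input untouched:

import torch

data = torch.zeros(3, 5)
out = torch.index_put(
    data,
    indices=[torch.tensor([0, 1, 2, 2]), torch.tensor([0, 1, 3, 4])],
    values=torch.tensor([2.0, 4.0, 7.0, 9.0]),
)
expected = torch.tensor(
    [
        [2.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 4.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 7.0, 9.0],
    ]
)
assert torch.equal(out, expected)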
