Skip to content

Commit 4046080

Browse files
committed
[Torch] Add index_put operator
1 parent d280118 commit 4046080

File tree

2 files changed

+60
-0
lines changed

2 files changed

+60
-0
lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2010,6 +2010,32 @@ def scatter(self, inputs, input_types):
20102010
src = inputs[3]
20112011
return _op.transform.scatter(data, index, src, axis)
20122012

2013+
def index_put(self, inputs, input_types):
    """Convert aten::index_put / aten::index_put_ to Relay.

    The per-axis index tensors are stacked into a single (N, ...) index
    tensor and scattered with relay.scatter_nd into a zero-initialized
    tensor whose shape is inferred from the input tensor.
    """
    data, indices, values, accumulate = (
        inputs[0],
        inputs[1],
        inputs[2],
        inputs[3],
    )
    # accumulate parameter is ignored.
    # torch.index_put default is False but Relay.scatter_nd accumulates values.
    # We assume there is no duplicate indices in torch.index_put input
    if not accumulate:
        logging.warning(
            "torch.index_put accumulate parameter is False. "
            "TVM uses tvm.relay.scatter_nd operator which accumulates values. "
            "Make sure there is no duplicate indices in torch.index_put input."
        )
    # Relay scatter_nd does not support input tensor
    # We assume that torch.index_put is used with empty zero-values input tensor
    # scatter_nd will create empty zero-values tensor with a given shape
    out_shape = self.infer_shape(data)
    logging.warning(
        "tvm.relay.scatter_nd operator does not support input tensor parameter. "
        "TVM assumes that torch.index_put is used with empty zero-values input tensor"
    )
    # Stack the list of per-axis index tensors into one index tensor of shape (N, _).
    stacked_indices = _op.stack(indices, axis=0)
    return _op.transform.scatter_nd(values, stacked_indices, out_shape)
2038+
20132039
def scalar_tensor(self, inputs, input_types):
20142040
data = inputs[0]
20152041
cast_map = {
@@ -2326,6 +2352,8 @@ def create_convert_map(self):
23262352
"aten::nonzero": self.nonzero,
23272353
"aten::nonzero_numpy": self.nonzero_numpy,
23282354
"aten::scatter": self.scatter,
2355+
"aten::index_put": self.index_put,
2356+
"aten::index_put_": self.index_put,
23292357
"aten::scalar_tensor": self.scalar_tensor,
23302358
"aten::__interpolate": self.interpolate,
23312359
"aten::IntImplicit": self.identity,

tests/python/frontend/pytorch/test_forward.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3327,6 +3327,38 @@ def test_fn_scatter_add(dim):
33273327
verify_trace_model(test_fn_scatter_add(1), [in_data, in_index, in_src], targets)
33283328

33293329

3330+
def test_forward_index_put():
    """Trace-model tests for the torch.index_put conversion."""

    # torch.index_put for 2D tensor and default accumulate (False)
    def index_put_2d():
        return lambda data, xidx, yidx, values: torch.index_put(
            data, indices=[xidx, yidx], values=values
        )

    # torch.index_put for 3D tensor and accumulate=True
    def index_put_3d_accumulate():
        return lambda data, xidx, yidx, zidx, values: torch.index_put(
            data, indices=[xidx, yidx, zidx], values=values, accumulate=True
        )

    targets = ["llvm", "cuda"]

    # 2D case
    in_data = torch.zeros((3, 5))
    xidx = torch.tensor([0, 1, 2, 2])
    yidx = torch.tensor([0, 1, 3, 4])
    values = torch.tensor([2.0, 4.0, 7.0, 9.0])
    verify_trace_model(index_put_2d(), [in_data, xidx, yidx, values], targets)

    # 3D case with accumulation (duplicate index at [0, 0, 0])
    in_data = torch.zeros((3, 5, 3))
    xidx = torch.tensor([0, 1, 2, 2, 0])
    yidx = torch.tensor([0, 1, 3, 4, 0])
    zidx = torch.tensor([0, 1, 1, 2, 0])
    values = torch.tensor([2.0, 4.0, 7.0, 9.0, 1.0])
    verify_trace_model(
        index_put_3d_accumulate(), [in_data, xidx, yidx, zidx, values], targets
    )
33303362
def test_numel():
33313363
class Numel(Module):
33323364
def forward(self, data):

0 commit comments

Comments (0)