184 changes: 169 additions & 15 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -4223,6 +4223,30 @@ def aten_index_put(
See implementation of `torch.onnx.symbolic_opset11.index_put
<https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
"""
if (
len(indices) > 1
and any(
isinstance(index, torch.onnx._internal.exporter._tensors.SymbolicTensor) # pylint: disable=protected-access
for index in indices
)
and len(values.shape) == 1
Collaborator: What is this condition for? I am just trying to understand the assumptions/conditions for this special case.

):
return _aten_index_put_dynamic(self, indices, values, accumulate=accumulate)

n_none = [i for i, ind in enumerate(indices) if ind is not None]
Collaborator: I guess n_none is short for not_none ... may be better to call it not_none

if (
len(n_none) == 1
and len(indices[n_none[0]].shape) == 1
and len(self.shape) == len(values.shape)
):
return _aten_index_put_scatter_nd(self, indices, values, accumulate)

if len(indices) == 1 and set(indices[0].shape[:-1]) == {1} and indices[0].shape[0] == 1:
# shape(self) = (5,5), shape(indices[0]) = (1,2), shape(values) = (2,5)
# This case was only found in the ops_data tests.
return _aten_index_put_scatter_nd(
self, [op.Reshape(indices[0], [-1])], values, accumulate
)

def _make_reshape_list_broadcastable(reshape_list, values_shape):
# Remove ones until the rank of reshape_list matches values_shape.
@@ -4235,7 +4259,13 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
# the reshape list should be : [[2, 1], [1, 3], [2, 1]]
for i, r in enumerate(reshape_list):
if r not in (1, values_shape[i]):
value_index = values_shape.index(r)
try:
value_index = values_shape.index(r)
except ValueError as e:
raise RuntimeError(
f"Unable to find element {r!r} in shape {values_shape}, "
f"reshape_list={reshape_list}"
) from e
# Swap elements
# For the example above the current reshape list is [1, 2] for last dim,
# to make it broadcastable, we swap the elements
@@ -4259,15 +4289,22 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
reshape_update = self.shape[i]
else:
idx = indices[i]
reshape_update = math.prod(idx.shape)
# when Index is more than 1D, flatten it and also the values shape
# Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3)
# Indices -> (2*4,) and values shape (2*4, 32)
if len(idx.shape) > 1:
values_shape = (reshape_update, *values_shape[len(idx.shape) :])

# Flatten index (always working with 1D index in each dim)
idx = op.Reshape(idx, [-1])
if all(isinstance(s, int) for s in idx.shape):
reshape_update = math.prod(idx.shape)
# When the index is more than 1-D, flatten it and flatten the values shape too.
# Example: self shape: (10, 3), indices[i] shape: (2, 4), values shape: (2, 4, 3)
# Indices -> (2*4,) and values shape -> (2*4, 3)
if len(idx.shape) > 1:
values_shape = (reshape_update, *values_shape[len(idx.shape) :])

# Flatten index (always working with 1D index in each dim)
idx = op.Reshape(idx, [-1])
else:
raise RuntimeError(
f"Unable to handle index {indices[i]} for axis={i} "
f"because one of the dimension is not static as shape="
f"{idx.shape}, indices={indices}"
)

# Create a reshape pattern: one value per index dimension,
# with the current dimension set to the update size.
@@ -4292,14 +4329,131 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape):
# Flatten values to match the indices
flat_values = op.Reshape(values, [-1])

if accumulate:
result = op.ScatterND(self, new_index, flat_values, reduction="add")
else:
result = op.ScatterND(self, new_index, flat_values)

scatter_kwargs = dict(reduction="add") if accumulate else {}
result = op.ScatterND(self, new_index, flat_values, **scatter_kwargs)
return result


def _aten_index_put_scatter_nd(
x: TReal,
indices: Sequence[INT64],
values: TReal,
accumulate: bool = False,
) -> TReal:
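"""Scatter values into x along the single axis whose index is not None.

The 1-D index is unsqueezed to shape (n, 1) so that ScatterND updates whole
slices; if the indexed axis is not axis 0, it is first swapped to the front
with a Transpose and swapped back afterwards.
"""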
def _1dint(i: int):
return op.Constant(value_ints=ir.AttrInt64s("value_ints", [i]))

n_none = [i for i, ind in enumerate(indices) if ind is not None]
assert len(n_none) == 1, f"Unable to handle this case: n_none={n_none}"
unsq = op.Unsqueeze(indices[n_none[0]], _1dint(1))
if n_none[0] == 0:
return op.ScatterND(x, unsq, values, reduction="add" if accumulate else "none")

perm = list(range(len(x.shape)))
perm[n_none[0]], perm[0] = perm[0], perm[n_none[0]]
return op.Transpose(
op.ScatterND(
op.Transpose(x, perm=perm),
unsq,
op.Transpose(values, perm=perm),
reduction="add" if accumulate else "none",
),
perm=perm,
)


def _aten_index_put_dynamic(
x: TReal,
indices: Sequence[INT64],
values: TReal,
accumulate: bool = False,
) -> TReal:
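"""Lower index_put to ONNX when index shapes are symbolic (dynamic).

None entries are materialized with Range over the corresponding dimension of
x; each per-axis index is multiplied by that axis's stride and the
contributions are summed into a single linear index. values is reshaped and
expanded to the matching element count, and the update is applied with
ScatterElements on the flattened x before reshaping back to the input shape.
"""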
def _1dint(i: int):
return op.Constant(value_ints=ir.AttrInt64s("value_ints", [i]))

def _0dint(i: int):
return op.Constant(value_int=ir.AttrInt64("value_int", i))

def _make_range_or_cast(ind, shape_x, static_shape: bool, dim: int):
if ind is not None:
return op.Cast(ind, to=INT64.dtype), False
return (
op.Cast(
op.Range( # Range does not return a typed result
_0dint(0),
op.Squeeze(op.Shape(x, start=dim, end=dim + 1)),
_0dint(1),
),
to=INT64.dtype,
),
True,
)

shape_x = op.Shape(x)
exped = []
fixed = []
reshape_value_shape2 = []
expand_value_shape = []
for i, ind in enumerate(indices):
if isinstance(ind, torch.onnx._internal.exporter._tensors.SymbolicTensor): # pylint: disable=protected-access
ind.dtype = ir.DataType.INT64
ind, expanded = _make_range_or_cast(ind, shape_x, False, i)
if expanded:
exped.append((i, ind))
expand_value_shape.append(op.Shape(x, start=i, end=i + 1))
reshape_value_shape2.append(_1dint(1))
else:
expand_value_shape.append(_1dint(1))
reshape_value_shape2.append(op.Shape(ind))
fixed.append((i, ind))

reshape_value_shape1 = [_1dint(1)] * len(indices)
if len(fixed) <= 1:
reshape_value_shape1 = None
elif fixed:
reshape_value_shape1[fixed[-1][0]] = _1dint(-1)

def _mkstride(x, i):
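# Stride of axis i in elements, i.e. the product of the trailing dimensions,
# computed with Shape/ReduceProd so it also works for symbolic shapes.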
if i >= len(x.shape) - 1:
return _1dint(1)
if i == len(x.shape) - 2:
return op.Shape(x, start=i + 1)
return op.ReduceProd(op.Shape(x, start=i + 1), keepdims=1)

shape = [1] * (len(x.shape) + 1)
r_fixed = []
if fixed:
new_shape = shape.copy()
new_shape[-1] = -1
r_fixed = [op.Reshape(op.Mul(_mkstride(x, i), f), new_shape) for i, f in fixed]

r_exped = []
for i, e in exped:
new_shape = shape.copy()
new_shape[i] = -1
r_exped.append(op.Reshape(op.Mul(_mkstride(x, i), e), new_shape))

# final sum
unflat = None
for a in [*r_fixed, *r_exped]:
if unflat is None:
unflat = a
continue
unflat = op.Add(unflat, a)

# value_shape
expanded_values = values
if reshape_value_shape1 is not None:
expanded_values = op.Reshape(expanded_values, op.Concat(*reshape_value_shape1, axis=0))
expanded_values = op.Expand(expanded_values, op.Concat(*expand_value_shape, axis=0))
flat_ind = op.Reshape(unflat, _1dint(-1))
expanded_values = op.Reshape(expanded_values, _1dint(-1))
flat_x = op.Reshape(x, _1dint(-1))
scat_kwargs = {"reduction": "add"} if accumulate else {}
flat_up_x = op.ScatterElements(flat_x, flat_ind, expanded_values, **scat_kwargs)
return op.Reshape(flat_up_x, op.Shape(x))
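
For intuition, here is a minimal NumPy sketch (illustrative only, not part of the change; the helper name index_put_flat and its variables are made up) of the linear-index strategy that _aten_index_put_dynamic lowers to ONNX. It assumes, as in the new tests, that the None entries come before the given index tensors and that values already broadcasts against the combined index shape.

import numpy as np

def index_put_flat(x, indices, values, accumulate=False):
    # Stride of each axis in elements, mirroring _mkstride: product of trailing dims.
    ndim = x.ndim
    strides = [int(np.prod(x.shape[d + 1:], dtype=np.int64)) for d in range(ndim)]
    lin = np.zeros((1,) * (ndim + 1), dtype=np.int64)
    for d in range(ndim):
        ind = indices[d] if d < len(indices) else None
        shape = [1] * (ndim + 1)
        if ind is None:
            # A None entry behaves like a full slice: give this axis its own
            # broadcast dimension (the Range/Reshape pair above).
            shape[d] = -1
            ind = np.arange(x.shape[d], dtype=np.int64)
        else:
            # Index tensors that are given broadcast together, so they all
            # share the trailing broadcast dimension.
            shape[-1] = -1
            ind = np.asarray(ind, dtype=np.int64)
        lin = lin + (ind * strides[d]).reshape(shape)
    flat_ind = lin.reshape(-1)
    upd = np.broadcast_to(values, lin.shape).reshape(-1)
    flat_x = x.reshape(-1).copy()
    if accumulate:
        np.add.at(flat_x, flat_ind, upd)  # like ScatterElements with reduction="add"
    else:
        flat_x[flat_ind] = upd  # like plain ScatterElements
    return flat_x.reshape(x.shape)

With x = np.zeros((2, 4, 5)), indices = [None, np.array([1, 2]), np.array([3, 4])] and values = np.array([10.0, 11.0]), this reproduces the eager result of copy[:, index1, index2] = update from test_index_put_dynamic below.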


@torch_op("aten::index_put", trace_only=True)
def aten_index_put_bool(
self: TReal,
104 changes: 104 additions & 0 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -5,6 +5,7 @@

import unittest

import numpy as np
import torch
from torch.onnx._internal.exporter import _testing

@@ -225,6 +226,109 @@ def forward(self, q, k, v):
)
_testing.assert_onnx_program(onnx_program)

def test_index_put_dynamic(self):
for dimension in [3, 4, 2]:
with self.subTest(dimension=dimension):

class Model(torch.nn.Module):
def __init__(self, dimension):
super().__init__()
self.params = torch.zeros(
(4, 5)
if dimension == 2
else ((2, 4, 5) if dimension == 3 else (1, 1, 4, 5))
)
self.dimension = dimension

def forward(self, update, index1, index2):
copy = self.params.clone()
if self.dimension == 2:
copy[index1, index2] = update
elif self.dimension == 3:
copy[:, index1, index2] = update
else:
copy[:, :, index1, index2] = update
return copy

update = (torch.arange(2) + 10).reshape((2,)).to(torch.float32)
index1 = torch.tensor([1, 2], dtype=torch.int64)
index2 = torch.tensor([3, 4], dtype=torch.int64)
feeds = dict(zip(["update", "index1", "index2"], (update, index1, index2)))
onnx_program = torch.onnx.export(
Model(dimension),
tuple(feeds.values()),
input_names=["update", "index1", "index2"],
output_names=["output"],
opset_version=18,
dynamo=True,
dynamic_shapes={
"update": {0: "dn"},
"index1": {0: "dn"},
"index2": {0: "dn"},
},
)
_testing.assert_onnx_program(onnx_program)

def test_index_put_55_12_25(self):
class Model(torch.nn.Module):
def forward(self, x, index, update):
return torch.ops.aten.index_put(x, [index], update)

x = torch.zeros((6, 5), dtype=torch.float32)
index = torch.tensor([[2, 1]], dtype=torch.int64)
update = (torch.arange(10) + 10).reshape((2, -1)).to(torch.float32)
onnx_program = torch.onnx.export(
Model(),
(x, index, update),
input_names=["x", "index", "update"],
output_names=["output"],
opset_version=18,
dynamo=True,
)
_testing.assert_onnx_program(onnx_program)

def test_index_put_55_2_25(self):
class Model(torch.nn.Module):
def forward(self, x, index, update):
return torch.ops.aten.index_put(x, [index], update, accumulate=True)

x = torch.ones((6, 5), dtype=torch.float32)
index = torch.tensor([4, 3], dtype=torch.int64)
update = (torch.arange(10) + 10).reshape((2, -1)).to(torch.float32)
onnx_program = torch.onnx.export(
Model(),
(x, index, update),
input_names=["x", "index", "update"],
output_names=["output"],
opset_version=18,
dynamo=True,
)
_testing.assert_onnx_program(onnx_program)

def test_index_put_scatter_nd(self):
class Model(torch.nn.Module):
def forward(self, x, index, update):
x = x.clone()
return torch.ops.aten.index_put(x, [None, index, None], update)

shape = (2, 3, 2)
N = int(np.prod(shape))
x = torch.arange(N, dtype=torch.float32).reshape(shape)
update = (torch.arange(N, dtype=torch.float32).reshape(shape) + 1) * 100
index = ((torch.arange(shape[-2])).to(torch.int64) + 1) % shape[-2]

feeds = dict(zip(["x", "index", "update"], (x, index, update)))
onnx_program = torch.onnx.export(
Model(),
tuple(feeds.values()),
input_names=["x", "index", "update"],
output_names=["output"],
opset_version=18,
dynamo=True,
dynamic_shapes=({0: "a", 1: "b", 2: "c"}, {0: "d"}, {0: "e", 1: "f", 2: "g"}),
)
_testing.assert_onnx_program(onnx_program)
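
As a quick sanity check of the semantics this test exercises (a hedged sketch, not part of the PR), index_put with [None, index, None] writes whole slices along dim 1, which is equivalent to advanced indexing on that axis in eager PyTorch:

import torch

shape = (2, 3, 2)
x = torch.arange(12, dtype=torch.float32).reshape(shape)
update = (torch.arange(12, dtype=torch.float32).reshape(shape) + 1) * 100
index = (torch.arange(shape[-2]) + 1) % shape[-2]  # [1, 2, 0]

out = torch.ops.aten.index_put(x, [None, index, None], update)
ref = x.clone()
ref[:, index, :] = update
assert torch.equal(out, ref)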

def test_bitwise_and_scalar(self):
class Model(torch.nn.Module):
def forward(self, x):
5 changes: 4 additions & 1 deletion tests/function_libs/torch_lib/ops_test_data.py
@@ -313,7 +313,10 @@ def _im2col_input_wrangler(
def _index_put_input_wrangler(
args: list[Any], kwargs: dict[str, Any]
) -> tuple[list[Any], dict[str, Any]]:
args[1] = [np.array(elem) for elem in args[1]]
args[1] = [
(elem.detach().cpu().numpy() if hasattr(elem, "detach") else np.array(elem))
for elem in args[1]
]
return args, kwargs

