-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 5 No.2】Add index_fill / index_fill_ API to Paddle -part #57416
Changes from all commits
d5cbe06
df9168c
e00a55e
c49cd3d
dfb2947
42b76c7
be92d15
7e49400
752035f
bf0ccf6
faa792f
d87864b
59e28c3
ba887af
1b4e076
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5422,3 +5422,104 @@ def unfold(x, axis, size, step, name=None): | |
} | ||
for name, func in __METHODS.items(): | ||
setattr(core.eager.Tensor, name, func) | ||
|
||
|
||
def _index_fill_impl(x, index, axis, value, inplace): | ||
if not isinstance(index, Variable): | ||
raise ValueError("index must be Tensor") | ||
|
||
if not isinstance(value, Variable): | ||
value = paddle.to_tensor(value, dtype=x.dtype) | ||
else: | ||
if len(value.shape) > 0: | ||
raise ValueError("value must be scalar or 0-D tensor") | ||
|
||
x_dim = len(x.shape) | ||
if not (isinstance(axis, int)) or (axis > x_dim - 1) or axis < -x_dim: | ||
raise ValueError( | ||
"The axis should be int, and in range [-rank(x), rank(x))" | ||
) | ||
|
||
if axis < 0: | ||
axis = axis + x_dim | ||
|
||
perm = list(range(len(x.shape))) | ||
perm[0] = axis | ||
perm[axis] = 0 | ||
|
||
if inplace: | ||
paddle.transpose(x, perm) | ||
paddle.index_put_(x, (index,), value) | ||
return x | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 此处,如果是inplace,是否可以不要调用clone+setitem赋值,而是直接使用index_put_赋值; 如果非inplace,是否可以不需要额外的clone操作 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the implementation solution in rfc shoule be also changed. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. current implementation is same with rfc API design already There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in rfc this line: |
||
else: | ||
out = paddle.transpose(x, perm) | ||
out = paddle.index_put(out, (index,), value) | ||
out = paddle.transpose(out, perm) | ||
return out | ||
|
||
|
||
def index_fill(x, index, axis, value, name=None): | ||
""" | ||
Outplace version of ``index_fill_`` API, the output Tensor will be inplaced with input ``x``. | ||
Please refer to :ref:`api_paddle_index_fill_`. | ||
|
||
Examples: | ||
.. code-block:: python | ||
|
||
>>> import paddle | ||
>>> input_tensor = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype='int64') | ||
>>> index = paddle.to_tensor([0, 2], dtype="int32") | ||
>>> value = -1 | ||
>>> res = paddle.index_fill(input_tensor, index, 0, value) | ||
>>> print(input_tensor) | ||
Tensor(shape=[3, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True, | ||
[[1, 2, 3], | ||
[4, 5, 6], | ||
[7, 8, 9]]) | ||
>>> print(res) | ||
Tensor(shape=[3, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True, | ||
[[-1, -1, -1], | ||
[ 4, 5, 6], | ||
[-1, -1, -1]]) | ||
|
||
""" | ||
return _index_fill_impl(x, index, axis, value, False) | ||
|
||
|
||
@inplace_apis_in_dygraph_only | ||
def index_fill_(x, index, axis, value, name=None): | ||
""" | ||
Fill the elements of the input tensor with value by the spcific axis and index. | ||
|
||
Args: | ||
x (Tensor) : The Destination Tensor. Supported data types are int32, int64, float16, float32, float64. | ||
index (Tensor): The 1-D Tensor containing the indices to index. | ||
The data type of ``index`` must be int32 or int64. | ||
axis (int): The dimension along which to index. | ||
value (float): The tensor used to fill with. | ||
name(str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. | ||
|
||
Returns: | ||
Tensor, same dimention and dtype with x. | ||
|
||
Examples: | ||
.. code-block:: python | ||
|
||
>>> import paddle | ||
>>> input_tensor = paddle.to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype='int64') | ||
>>> index = paddle.to_tensor([0, 2], dtype="int32") | ||
>>> value = -1 | ||
>>> res = paddle.index_fill_(input_tensor, index, 0, value) | ||
>>> print(input_tensor) | ||
Tensor(shape=[3, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True, | ||
[[-1, -1, -1], | ||
[ 4, 5, 6], | ||
[-1, -1, -1]]) | ||
>>> print(res) | ||
Tensor(shape=[3, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True, | ||
[[-1, -1, -1], | ||
[ 4, 5, 6], | ||
[-1, -1, -1]]) | ||
|
||
""" | ||
return _index_fill_impl(x, index, axis, value, True) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import unittest | ||
from itertools import combinations | ||
|
||
import numpy as np | ||
|
||
import paddle | ||
from paddle.base import Program | ||
|
||
paddle.enable_static() | ||
|
||
|
||
def compute_index_fill_ref(x, axis, index, value): | ||
perm = list(range(len(x.shape))) | ||
perm[0] = axis | ||
perm[axis] = 0 | ||
|
||
out = np.transpose(x, perm) | ||
out[index] = value | ||
out = np.transpose(out, perm) | ||
return out | ||
|
||
|
||
class TestIndexFillAPIBase(unittest.TestCase): | ||
def setUp(self): | ||
self.init_setting() | ||
self.modify_setting() | ||
self.x_np = np.random.random(self.x_shape).astype(self.dtype_np) | ||
self.index_np = np.array(self.combs[np.random.randint(0, 252)]).astype( | ||
self.index_type | ||
) | ||
|
||
self.place = ['cpu'] | ||
if self.dtype_np == 'float16': | ||
self.place = [] | ||
if paddle.is_compiled_with_cuda(): | ||
self.place.append('gpu') | ||
|
||
def init_setting(self): | ||
self.dtype_np = 'float64' | ||
self.index_type = 'int64' | ||
self.x_shape = (20, 40) | ||
self.index_size = (5,) | ||
self.axis = 0 | ||
self.value = -1 | ||
self.combs = list(combinations(list(range(10)), self.index_size[0])) | ||
|
||
def modify_setting(self): | ||
pass | ||
|
||
def test_static_graph(self): | ||
paddle.enable_static() | ||
for place in self.place: | ||
with paddle.static.program_guard(Program()): | ||
x = paddle.static.data( | ||
name="x", shape=self.x_shape, dtype=self.dtype_np | ||
) | ||
index = paddle.static.data( | ||
name="index", shape=self.index_size, dtype=self.index_type | ||
) | ||
out = paddle.index_fill(x, index, self.axis, self.value) | ||
exe = paddle.static.Executor(place=place) | ||
feed_list = {"x": self.x_np, "index": self.index_np} | ||
pd_res = exe.run( | ||
paddle.static.default_main_program(), | ||
feed=feed_list, | ||
fetch_list=[out], | ||
)[0] | ||
ref_res = compute_index_fill_ref( | ||
self.x_np, self.axis, self.index_np, self.value | ||
) | ||
np.testing.assert_allclose(ref_res, pd_res) | ||
|
||
def test_dygraph(self): | ||
paddle.disable_static() | ||
for place in self.place: | ||
paddle.device.set_device(place) | ||
x_pd = paddle.to_tensor(self.x_np) | ||
index_pd = paddle.to_tensor(self.index_np) | ||
pd_res = paddle.index_fill(x_pd, index_pd, self.axis, self.value) | ||
ref_res = compute_index_fill_ref( | ||
self.x_np, self.axis, self.index_np, self.value | ||
) | ||
np.testing.assert_allclose(ref_res, pd_res) | ||
|
||
def test_errors(self): | ||
data_np = np.random.random((10, 10)).astype(np.float32) | ||
index = paddle.to_tensor([0, 2]) | ||
|
||
def test_index_not_tensor(): | ||
res = paddle.index_fill(data_np, [0, 2], axis=-1, value=-1) | ||
|
||
self.assertRaises(ValueError, test_index_not_tensor) | ||
|
||
def test_value_shape(): | ||
res = paddle.index_fill( | ||
data_np, index, axis=-1, value=paddle.to_tensor([-1, -4]) | ||
) | ||
|
||
self.assertRaises(ValueError, test_value_shape) | ||
|
||
def test_axis_range(): | ||
res = paddle.index_fill(data_np, index, axis=4, value=-1) | ||
|
||
self.assertRaises(ValueError, test_axis_range) | ||
|
||
|
||
class TestIndexFillAPI1(TestIndexFillAPIBase): | ||
def modify_setting(self): | ||
self.dtype_np = 'int64' | ||
self.index_type = 'int32' | ||
self.x_shape = (10, 15, 10) | ||
self.axis = 1 | ||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 补充下complex类型的测试吧 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. index_put不支持complex类型的输入 增加了float16类型的测试 |
||
class TestIndexFillAPI2(TestIndexFillAPIBase): | ||
def modify_setting(self): | ||
self.dtype_np = 'bool' | ||
self.index_type = 'int32' | ||
self.x_shape = (10, 15, 10) | ||
self.axis = 1 | ||
self.value = True | ||
|
||
|
||
class TestIndexFillAPI3(TestIndexFillAPIBase): | ||
def modify_setting(self): | ||
self.dtype_np = 'float16' | ||
self.x_shape = (10, 15, 10) | ||
self.axis = 1 | ||
self.value = 0.5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The negative
axis
has been processed in L5183 above, so the judgment condition here should beaxis < 0
notaxis < -x_dim
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done