Skip to content

[XPU] add bf16/fp16 support for index_put/_grad #69970

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions paddle/phi/backends/xpu/xpu3_op_list.cc
Original file line number Diff line number Diff line change
Expand Up @@ -660,10 +660,14 @@ XPUOpMap& get_kl3_ops() {
{"index_put",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT64})},
{"index_put_grad",
XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::INT32,
phi::DataType::FLOAT16,
phi::DataType::BFLOAT16,
phi::DataType::INT64})},
{"index_sample_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"index_sample",
Expand Down
4 changes: 3 additions & 1 deletion paddle/phi/kernels/xpu/index_put_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ void IndexPutGradKernel(const Context& dev_ctx,
std::copy(xshape.begin() + int_indices_v.size(),
xshape.end(),
value_shape_bd.begin() + index_shape.size() - 1);
auto value_shape = common::vectorize<int64_t>(value_grad->dims());
int ret = xpu::SUCCESS;
using XPUType = typename XPUTypeTrait<T>::Type;
if (x_grad) {
Expand All @@ -95,6 +94,7 @@ void IndexPutGradKernel(const Context& dev_ctx,
}
}
if (value_grad) {
auto value_shape = common::vectorize<int64_t>(value_grad->dims());
dev_ctx.template Alloc<T>(value_grad);
if (value_shape != value_shape_bd) {
std::vector<int64_t> compress_dims;
Expand Down Expand Up @@ -140,5 +140,7 @@ PD_REGISTER_KERNEL(index_put_grad,
ALL_LAYOUT,
phi::IndexPutGradKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t) {}
11 changes: 9 additions & 2 deletions paddle/phi/kernels/xpu/index_put_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,12 @@ void IndexPutKernel(const Context& dev_ctx,
}
} // namespace phi

PD_REGISTER_KERNEL(
index_put, XPU, ALL_LAYOUT, phi::IndexPutKernel, float, int, int64_t) {}
PD_REGISTER_KERNEL(index_put,
XPU,
ALL_LAYOUT,
phi::IndexPutKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t) {}
1 change: 1 addition & 0 deletions test/legacy_test/op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,7 @@ def is_complex_test():
not cls.input_shape_is_large
and cls.op_type
not in check_shape_white_list.NEED_TO_FIX_OP_LIST
and not is_xpu_op_test()
):
raise AssertionError(
"Number of element(s) of input should be large than or equal to 100 for "
Expand Down
105 changes: 61 additions & 44 deletions test/xpu/test_index_put_op_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
create_test_class,
get_xpu_op_support_types,
)
from op_test import convert_float_to_uint16, convert_uint16_to_float
from op_test_xpu import XPUOpTest

import paddle
Expand Down Expand Up @@ -104,11 +105,18 @@ def set_case(self):

def init_data(self):
x_np = ((np.random.random(self.x_shape) - 0.5) * 10.0).astype(
self.dtype
"float32"
)
value_np = (
(np.random.random(self.value_shape) - 0.5) * 10.0
).astype(self.dtype)
).astype("float32")

if self.dtype == np.uint16:
x_np = convert_float_to_uint16(x_np)
value_np = convert_float_to_uint16(value_np)
else:
x_np = x_np.astype(self.dtype)
value_np = value_np.astype(self.dtype)

if self.mixed_indices:
tmp_indices_np1 = gen_indices_np(
Expand Down Expand Up @@ -149,12 +157,21 @@ def init_data(self):
if self.is_all_false:
out_np = x_np
else:
out_np = compute_index_put_ref(
copy.deepcopy(x_np),
self.indices_np,
value_np,
self.accumulate,
)
if self.dtype == np.uint16:
out_np = compute_index_put_ref(
convert_uint16_to_float(copy.deepcopy(x_np)),
self.indices_np,
convert_uint16_to_float(value_np),
self.accumulate,
)
out_np = convert_float_to_uint16(out_np)
else:
out_np = compute_index_put_ref(
copy.deepcopy(x_np),
self.indices_np,
value_np,
self.accumulate,
)
self.outputs = {'out': out_np}

def get_indices_names(self):
Expand All @@ -172,49 +189,49 @@ def test_check_grad(self):
class TestXPUIndexPut1(TestXPUIndexPutOp):
def set_case(self):
self.index_dtype = np.int64
self.x_shape = (110, 42, 56, 56)
self.indices_shapes = [(16, 16), (16, 16), (1, 16), (1, 16)]
self.value_shape = (16, 16)
self.x_shape = (48, 26, 56)
self.indices_shapes = [(16, 16), (16, 16), (1, 16)]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After adding the two new data types, some CI runs timed out, so the scale of several unit tests has been reduced here. (Original comment in Chinese: 加了两个数据类型之后CI中有超时,这里调小了一些单测的规模)

self.value_shape = [16, 16]
self.accumulate = False

class TestXPUIndexPut2(TestXPUIndexPutOp):
def set_case(self):
self.index_dtype = np.int64
self.x_shape = (110, 42, 56, 56)
self.indices_shapes = [(16, 16), (16, 16), (1, 16), (1, 16)]
self.x_shape = (48, 26, 56)
self.indices_shapes = [(16, 16), (16, 16), (1, 16)]
self.value_shape = (16, 16)
self.accumulate = True

class TestXPUIndexPut3(TestXPUIndexPutOp):
def set_case(self):
self.index_dtype = np.bool_
self.x_shape = (110, 94)
self.indices_shapes = [(110, 94)]
self.value_shape = (5170,)
self.x_shape = (12, 94)
self.indices_shapes = [(12, 94)]
self.value_shape = (564,)
self.accumulate = False

class TestXPUIndexPut4(TestXPUIndexPutOp):
def set_case(self):
self.index_dtype = np.bool_
self.x_shape = (110, 94)
self.indices_shapes = [(110, 94)]
self.value_shape = (5170,)
self.x_shape = (11, 94)
self.indices_shapes = [(11, 94)]
self.value_shape = (564,)
self.accumulate = True

class TestXPUIndexPut5(TestXPUIndexPutOp):
def set_case(self):
self.index_dtype = np.int32
self.x_shape = (110, 42, 56, 56)
self.indices_shapes = ((16, 16), (16, 16), (1, 16))
self.value_shape = (16, 16, 56)
self.x_shape = (17, 32, 26, 36)
self.indices_shapes = ((8, 8), (8, 8), (1, 8))
self.value_shape = (8, 8, 36)
self.accumulate = False

class TestXPUIndexPut6(TestXPUIndexPutOp):
def set_case(self):
self.index_dtype = np.int32
self.x_shape = (110, 42, 56, 56)
self.indices_shapes = ((16, 16), (16, 16), (1, 16))
self.value_shape = (16, 16, 56)
self.x_shape = (17, 32, 26, 36)
self.indices_shapes = ((8, 8), (8, 8), (1, 8))
self.value_shape = (8, 8, 36)
self.accumulate = True

class TestXPUIndexPut7(TestXPUIndexPutOp):
Expand All @@ -237,32 +254,32 @@ def set_case(self):
class TestXPUIndexPut9(TestXPUIndexPutOp):
def set_case(self):
self.index_dtype = np.int64
self.x_shape = (110, 42, 56, 56)
self.indices_shapes = ((16, 16), (16, 16), (1, 16))
self.value_shape = (56,)
self.x_shape = (17, 32, 26, 36)
self.indices_shapes = ((8, 8), (8, 8), (1, 8))
self.value_shape = (36,)
self.accumulate = False

class TestXPUIndexPut10(TestXPUIndexPutOp):
def set_case(self):
self.index_dtype = np.int64
self.x_shape = (110, 42, 56, 56)
self.indices_shapes = ((16, 16), (16, 16), (1, 16))
self.value_shape = (56,)
self.x_shape = (17, 32, 26, 36)
self.indices_shapes = ((8, 8), (8, 8), (8, 8))
self.value_shape = (36,)
self.accumulate = True

class TestXPUIndexPut11(TestXPUIndexPutOp):
def set_case(self):
self.index_dtype = np.int64
self.x_shape = (110, 42, 56, 56)
self.indices_shapes = ((16, 16), (16, 16), (1, 16))
self.x_shape = (17, 32, 26, 36)
self.indices_shapes = ((8, 8), (8, 8), (8, 8))
self.value_shape = (1,)
self.accumulate = False

class TestXPUIndexPut12(TestXPUIndexPutOp):
def set_case(self):
self.index_dtype = np.int64
self.x_shape = (110, 42, 56, 56)
self.indices_shapes = ((16, 16), (16, 16), (1, 16))
self.x_shape = (17, 32, 26, 36)
self.indices_shapes = ((8, 8), (8, 8), (1, 8))
self.value_shape = (1,)
self.accumulate = True

Expand Down Expand Up @@ -317,26 +334,26 @@ def set_case(self):
class TestXPUIndexPutMixedIndices(TestXPUIndexPutOp):
def set_case(self):
self.index_dtype = np.int32
self.x_shape = (110, 42, 32, 56)
self.indices_shapes = ((16, 16), (16, 16))
self.value_shape = (16, 16, 56)
self.x_shape = (17, 32, 16, 36)
self.indices_shapes = ((8, 8), (8, 8))
self.value_shape = (8, 8, 36)
self.accumulate = False

self.mixed_indices = True
self.index_dtype1 = np.bool_
self.indices_shapes1 = [(32,)]
self.indices_shapes1 = [(16,)]

class TestXPUIndexPutMixedIndices1(TestXPUIndexPutOp):
def set_case(self):
self.index_dtype = np.int32
self.x_shape = (110, 42, 32, 56)
self.indices_shapes = ((16, 16), (16, 16))
self.value_shape = (16, 16, 56)
self.x_shape = (17, 32, 16, 36)
self.indices_shapes = ((8, 8), (8, 8))
self.value_shape = (8, 8, 36)
self.accumulate = True

self.mixed_indices = True
self.index_dtype1 = np.bool_
self.indices_shapes1 = [(32,)]
self.indices_shapes1 = [(16,)]


supported_type = get_xpu_op_support_types("index_put")
Expand All @@ -357,7 +374,7 @@ def setUp(self):
def init_dtype_type(self):
self.dtype_np = np.float32
self.index_type_np = np.int64
self.x_shape = (100, 110)
self.x_shape = (50, 55)
self.indices_shapes = [(21,), (21,)]
self.value_shape = (21,)
self.dtype_pd = paddle.float32
Expand Down