Skip to content

Commit 5516f18

Browse files
authored
[Phi] Add unbind yaml and final state api (#41277)
* add unbind yaml * fix unittest
1 parent edbb398 commit 5516f18

File tree

8 files changed

+97
-8
lines changed

8 files changed

+97
-8
lines changed

paddle/phi/api/lib/api_custom_impl.cc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,54 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
475475
return api_output;
476476
}
477477

478+
std::vector<Tensor> unbind_impl(const Tensor& input, int axis) {
479+
auto kernel_key_set = ParseKernelKeyByInputArgs(input);
480+
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
481+
482+
Backend kernel_backend = kernel_key.backend();
483+
DataLayout kernel_layout = kernel_key.layout();
484+
DataType kernel_data_type = kernel_key.dtype();
485+
486+
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
487+
"unbind", {kernel_backend, kernel_layout, kernel_data_type});
488+
VLOG(6) << "unbind API kernel key: [" << kernel_backend << ", "
489+
<< kernel_layout << ", " << kernel_data_type << "]";
490+
VLOG(6) << "unbind API kernel: " << kernel;
491+
492+
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
493+
494+
auto dense_input = PrepareData(input, kernel.InputAt(0), {});
495+
496+
// Calculate the number of out tensors
497+
auto input_shape = input.dims();
498+
if (axis < 0) {
499+
axis = input_shape.size() + axis;
500+
}
501+
auto out_num = input_shape[axis];
502+
503+
std::vector<Tensor> out;
504+
auto dense_outs = SetKernelOutput(out_num, kernel_backend, &out);
505+
std::vector<phi::MetaTensor> meta_outs;
506+
meta_outs.reserve(out_num);
507+
std::vector<phi::MetaTensor*> meta_out_ptrs;
508+
meta_out_ptrs.reserve(out_num);
509+
for (int64_t i = 0; i < out_num; ++i) {
510+
meta_outs.push_back(dense_outs[i]);
511+
meta_out_ptrs.push_back(&meta_outs.back());
512+
}
513+
514+
phi::UnbindInferMeta(MakeMetaTensor(*dense_input), axis, meta_out_ptrs);
515+
516+
using kernel_signature = void (*)(const phi::DeviceContext&,
517+
const phi::DenseTensor&,
518+
int,
519+
std::vector<phi::DenseTensor*>&);
520+
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
521+
(*kernel_fn)(*dev_ctx, *dense_input, axis, dense_outs);
522+
523+
return out;
524+
}
525+
478526
////////////////// Backward(grad) api impls //////////////////////
479527

480528
// TODO(chenweihang): the original sum grad op can support higher-level

paddle/phi/api/lib/api_custom_impl.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ limitations under the License. */
1414

1515
#pragma once
1616

17+
#include <vector>
18+
1719
#include "paddle/phi/api/include/tensor.h"
1820
#include "paddle/phi/common/int_array.h"
1921
#include "paddle/phi/common/place.h"
@@ -73,6 +75,8 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
7375
bool multi_precision,
7476
float rescale_grad);
7577

78+
std::vector<Tensor> unbind_impl(const Tensor& input, int axis);
79+
7680
////////////////// Backward(grad) api impls //////////////////////
7781

7882
std::vector<Tensor> add_n_grad_impl(const std::vector<Tensor>& x,

paddle/phi/infermeta/unary.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2429,7 +2429,7 @@ void TransposeGradInferMeta(const MetaTensor& x,
24292429

24302430
void UnbindInferMeta(const MetaTensor& x,
24312431
int axis,
2432-
std::vector<MetaTensor>* outs) {
2432+
std::vector<MetaTensor*> outs) {
24332433
auto in_dims = x.dims();
24342434
std::vector<int> out_dim;
24352435
axis = axis < 0 ? in_dims.size() + axis : axis;
@@ -2438,11 +2438,11 @@ void UnbindInferMeta(const MetaTensor& x,
24382438
}
24392439
auto out_dims = phi::make_ddim(out_dim);
24402440

2441-
for (size_t i = 0; i < outs->size(); ++i) {
2442-
(*outs)[i].set_dtype(x.dtype());
2443-
(*outs)[i].set_dims(out_dims);
2444-
(*outs)[i].set_layout(x.layout());
2445-
(*outs)[i].share_lod(x);
2441+
for (size_t i = 0; i < outs.size(); ++i) {
2442+
outs[i]->set_dtype(x.dtype());
2443+
outs[i]->set_dims(out_dims);
2444+
outs[i]->set_layout(x.layout());
2445+
outs[i]->share_lod(x);
24462446
}
24472447
}
24482448

paddle/phi/infermeta/unary.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,7 @@ void TrilTriuInferMeta(const MetaTensor& x,
365365

366366
void UnbindInferMeta(const MetaTensor& x,
367367
int axis,
368-
std::vector<MetaTensor>* outs);
368+
std::vector<MetaTensor*> outs);
369369

370370
void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out);
371371

python/paddle/fluid/tests/unittests/test_unbind_op.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
import unittest
1818
import numpy as np
1919
from op_test import OpTest, convert_float_to_uint16
20+
import paddle
2021
import paddle.fluid as fluid
2122
import paddle.tensor as tensor
2223
from paddle.fluid import compiler, Program, program_guard, core
24+
from paddle.fluid.framework import _test_eager_guard
2325

2426

2527
class TestUnbind(unittest.TestCase):
@@ -39,6 +41,25 @@ def test_unbind(self):
3941
assert np.array_equal(res_1, input_1[0, 0:100])
4042
assert np.array_equal(res_2, input_1[1, 0:100])
4143

44+
def test_unbind_dygraph(self):
45+
with fluid.dygraph.guard():
46+
np_x = np.random.random([2, 3]).astype("float32")
47+
x = paddle.to_tensor(np_x)
48+
x.stop_gradient = False
49+
[res_1, res_2] = paddle.unbind(x, 0)
50+
self.assertTrue(np.array_equal(res_1, np_x[0, 0:100]))
51+
self.assertTrue(np.array_equal(res_2, np_x[1, 0:100]))
52+
53+
out = paddle.add_n([res_1, res_2])
54+
55+
np_grad = np.ones(x.shape, np.float32)
56+
out.backward()
57+
self.assertTrue(np.array_equal(x.grad.numpy(), np_grad))
58+
59+
def test_unbind_dygraph_final_state(self):
60+
with _test_eager_guard():
61+
self.test_unbind_dygraph()
62+
4263

4364
class TestLayersUnbind(unittest.TestCase):
4465
def test_layers_unbind(self):
@@ -157,6 +178,7 @@ def outReshape(self):
157178
class TestUnbindBF16Op(OpTest):
158179
def setUp(self):
159180
self._set_op_type()
181+
self.python_api = paddle.unbind
160182
self.dtype = self.get_dtype()
161183
self.axis = 0
162184
self.num = 3

python/paddle/tensor/manipulation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1469,6 +1469,9 @@ def unbind(input, axis=0):
14691469
# x3.shape [3, 5]
14701470
14711471
"""
1472+
if in_dygraph_mode():
1473+
return _C_ops.final_state_unbind(input, axis)
1474+
14721475
if not isinstance(axis, (int)):
14731476
raise TypeError("The type of 'axis' must be int, but received %s." %
14741477
(type(axis)))
@@ -1477,7 +1480,7 @@ def unbind(input, axis=0):
14771480
input_shape = input.shape
14781481
axis_ = axis if axis >= 0 else len(input_shape) + axis
14791482
num = input_shape[axis_]
1480-
if paddle.in_dynamic_mode():
1483+
if _in_legacy_dygraph():
14811484
return _C_ops.unbind(input, num, 'axis', axis)
14821485

14831486
helper = LayerHelper("unbind", **locals())

python/paddle/utils/code_gen/api.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1939,6 +1939,12 @@
19391939
backend : place
19401940
data_type : dtype
19411941

1942+
- api : unbind
1943+
args : (Tensor input, int axis)
1944+
output : Tensor[]
1945+
invoke : unbind_impl(input, axis)
1946+
backward : unbind_grad
1947+
19421948
# unfold
19431949
- api : unfold
19441950
args : (Tensor x, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations)

python/paddle/utils/code_gen/backward.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,6 +1480,12 @@
14801480
kernel :
14811481
func : trunc_grad
14821482

1483+
- backward_api : unbind_grad
1484+
forward : unbind (Tensor input, int axis) -> Tensor[](out)
1485+
args : (Tensor[] out_grad, int axis)
1486+
output : Tensor(input_grad)
1487+
invoke : stack(out_grad, axis)
1488+
14831489
- backward_api : unfold_grad
14841490
forward : unfold (Tensor x, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations) -> Tensor(out)
14851491
args : (Tensor x, Tensor out_grad, int[] kernel_sizes, int[] strides, int[] paddings, int[] dilations)

0 commit comments

Comments
 (0)