Skip to content

Commit ff4bdac

Browse files
authored
fix a bug of slice by none index (#34877)
1 parent fc6b4a5 commit ff4bdac

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

paddle/fluid/pybind/imperative.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,29 @@ void BindImperative(py::module *m_ptr) {
921921
axis -= len;
922922
}
923923

924+
// Deal with cases that there are more than one
925+
// prefix none index, For example:
926+
// [None, None, :, :, None]
927+
// the none_axes int the return of ParseIndexingSlice is:
928+
// [0, 0, 2 ]
929+
// according to the interface of "unsqueeze2",
930+
// we should convert it to:
931+
// [0, 0, 4 ]
932+
int prefix_zero_cnt = 0;
933+
for (const auto &axis : none_axes) {
934+
if (axis == 0) {
935+
prefix_zero_cnt++;
936+
} else {
937+
break;
938+
}
939+
}
940+
if (prefix_zero_cnt > 0) {
941+
int none_axes_num = static_cast<int>(none_axes.size());
942+
for (int i = prefix_zero_cnt; i < none_axes_num; ++i) {
943+
none_axes[i] += prefix_zero_cnt;
944+
}
945+
}
946+
924947
imperative::NameVarBaseMap ins = {{"X", {out}}};
925948
framework::AttributeMap attrs = {{"axes", none_axes}};
926949
auto new_out = std::shared_ptr<imperative::VarBase>(

python/paddle/fluid/tests/unittests/test_var_base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,7 @@ def _test_none_index(self):
711711
var_tensor[None, 2, None, 1].numpy(),
712712
var_tensor[None].numpy(),
713713
var_tensor[0, 0, None, 0, 0, None].numpy(),
714+
var_tensor[None, None, 0, ..., None].numpy(),
714715
var_tensor[0, 1:10:2, None, None, ...].numpy(),
715716
]
716717

@@ -724,11 +725,13 @@ def _test_none_index(self):
724725
self.assertTrue(np.array_equal(var[7], np_value[None]))
725726
self.assertTrue(
726727
np.array_equal(var[8], np_value[0, 0, None, 0, 0, None]))
728+
self.assertTrue(
729+
np.array_equal(var[9], np_value[None, None, 0, ..., None]))
727730

728731
# TODO(zyfncg) there is a bug of dimensions when slice step > 1 and
729732
# indexs has int type
730733
# self.assertTrue(
731-
# np.array_equal(var[9], np_value[0, 1:10:2, None, None, ...]))
734+
# np.array_equal(var[10], np_value[0, 1:10:2, None, None, ...]))
732735

733736
def _test_for_var(self):
734737
np_value = np.random.random((30, 100, 100)).astype('float32')

0 commit comments

Comments
 (0)