@@ -25,15 +25,16 @@ namespace operators {
2525
2626using Tensor = framework::Tensor;
2727
28- void UpdateAttr (const framework::DDim in_dims, const std::vector<int > axes,
28+ void UpdateAttr (const framework::DDim& in_dims, const std::vector<int > axes,
2929 const std::vector<int > starts, const std::vector<int > ends,
3030 std::vector<int >* offsets, std::vector<int >* size) {
3131 int cnt = 0 ;
3232 for (int i = 0 ; i < in_dims.size (); ++i) {
3333 int start = 0 ;
3434 int end = in_dims[i];
35- int axis = axes[cnt];
36-
35+ // NOTE(zhiqiu): Becareful that cnt may > axes.size() and result in
36+ // overflow.
37+ int axis = cnt < static_cast <int >(axes.size ()) ? axes[cnt] : -1 ;
3738 if (axis == i) {
3839 start = starts[cnt];
3940 if (start < 0 ) {
@@ -63,10 +64,10 @@ class SliceNPUKernel : public framework::OpKernel<T> {
6364 auto axes = ctx.Attr <std::vector<int >>(" axes" );
6465 auto starts = ctx.Attr <std::vector<int >>(" starts" );
6566 auto ends = ctx.Attr <std::vector<int >>(" ends" );
67+ const auto & in_dims = input->dims ();
6668
6769 out->mutable_data <T>(ctx.GetPlace ());
6870
69- auto in_dims = input->dims ();
7071 std::vector<int > offsets (in_dims.size ());
7172 std::vector<int > size (in_dims.size ());
7273
@@ -93,8 +94,7 @@ class SliceGradNPUKernel : public framework::OpKernel<T> {
9394 auto axes = ctx.Attr <std::vector<int >>(" axes" );
9495 auto starts = ctx.Attr <std::vector<int >>(" starts" );
9596 auto ends = ctx.Attr <std::vector<int >>(" ends" );
96-
97- auto in_dims = input->dims ();
97+ const auto & in_dims = input->dims ();
9898 int rank = in_dims.size ();
9999
100100 std::vector<int > offsets (rank);
0 commit comments