Skip to content

Commit 3a0b0b1

Browse files
committed
fix top_k GitDims error. test=develop
1 parent b9e0025 commit 3a0b0b1

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

paddle/fluid/operators/top_k_v2_op.cu

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
#include "paddle/fluid/framework/eigen.h"
1616
#include "paddle/fluid/framework/op_registry.h"
17-
#include "paddle/fluid/operators/p_norm_op.h"
1817
#include "paddle/fluid/operators/top_k_function_cuda.h"
1918
#include "paddle/fluid/operators/top_k_v2_op.h"
2019

paddle/fluid/operators/top_k_v2_op.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,19 @@ limitations under the License. */
3333
namespace paddle {
3434
namespace operators {
3535

36+
inline void GetDims(const framework::DDim& dim, int axis, int* pre, int* n,
37+
int* post) {
38+
*pre = 1;
39+
*post = 1;
40+
*n = dim[axis];
41+
for (int i = 0; i < axis; ++i) {
42+
(*pre) *= dim[i];
43+
}
44+
for (int i = axis + 1; i < dim.size(); ++i) {
45+
(*post) *= dim[i];
46+
}
47+
}
48+
3649
template <typename T, typename Type>
3750
static void FullTopK(Type input_height, Type input_width, int input_dim,
3851
const framework::Tensor* input, T* t_out, Type* t_indices,

0 commit comments

Comments
 (0)