Skip to content

Commit ef6cb1a

Browse files
Copilotzrr1999
andcommitted
Fix int32 overflow in lstm, lstsq, qr_grad, and spectral_norm_grad impl
Co-authored-by: zrr1999 <46243324+zrr1999@users.noreply.github.com>
1 parent 91b6c2f commit ef6cb1a

File tree

4 files changed

+8
-8
lines changed

4 files changed

+8
-8
lines changed

paddle/phi/kernels/impl/lstm_kernel_impl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ void LSTMKernel(const Context& dev_ctx,
6060
to_batch(dev_ctx, input, batch_gate_new, true, is_reverse);
6161

6262
auto in_dims = input.dims();
63-
int frame_size = static_cast<int>(in_dims[1] / 4);
63+
int64_t frame_size = in_dims[1] / 4;
6464
phi::DDim dims({in_dims[0], frame_size});
6565

6666
if (bias.initialized()) {
@@ -254,7 +254,7 @@ void LSTMGradKernel(const Context& dev_ctx,
254254

255255
auto in_dims = input->dims();
256256
auto out_dims = hidden_g->dims();
257-
int frame_size = static_cast<int>(in_dims[1] / 4);
257+
int64_t frame_size = in_dims[1] / 4;
258258
PADDLE_ENFORCE_EQ(frame_size,
259259
out_dims[1],
260260
common::errors::InvalidArgument(

paddle/phi/kernels/impl/lstsq_kernel_impl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ inline void GetResidualsTensor(const DeviceContext& dev_ctx,
6666
DenseTensor* rank) {
6767
auto x_dims = x.dims();
6868
int dim_size = x_dims.size();
69-
int m = x_dims[dim_size - 2];
70-
int n = x_dims[dim_size - 1];
69+
int64_t m = x_dims[dim_size - 2];
70+
int64_t n = x_dims[dim_size - 1];
7171

7272
if (m > n && driver != "gelsy") {
7373
bool compute_residuals = true;

paddle/phi/kernels/impl/qr_grad_kernel_impl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ void QrGradKernel(const Context& dev_ctx,
7676

7777
auto a_dims = A.dims();
7878
int a_rank = a_dims.size();
79-
int m = a_dims[a_rank - 2];
80-
int n = a_dims[a_rank - 1];
79+
int64_t m = a_dims[a_rank - 2];
80+
int64_t n = a_dims[a_rank - 1];
8181

8282
if ((m > n) && (!reduced)) {
8383
PADDLE_THROW(errors::InvalidArgument(

paddle/phi/kernels/impl/spectral_norm_grad_kernel_impl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ void SpectralNormGradKernel(const Context& dev_ctx,
3131
auto& place = *dev_ctx.eigen_device();
3232
auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
3333

34-
const int h = u.dims()[0];
35-
const int w = v.dims()[0];
34+
const int64_t h = u.dims()[0];
35+
const int64_t w = v.dims()[0];
3636

3737
DenseTensor weight_mat, out_grad_mat;
3838
auto dims = weight.dims();

0 commit comments

Comments
 (0)