Skip to content

Commit 9697618

Browse files
committed
Fix the compiling error of update_loss_scaling when using cuda9.
1 parent b3fa899 commit 9697618

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

paddle/fluid/operators/amp/update_loss_scaling_op.h

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
#pragma once
1616

17+
#if defined(PADDLE_WITH_CUDA) && defined(__NVCC__)
18+
#include <cuda.h>
19+
#endif // PADDLE_WITH_CUDA && __NVCC__
1720
#include <cmath>
1821
#include <vector>
1922
#include "paddle/fluid/framework/operator.h"
@@ -29,13 +32,23 @@ namespace operators {
2932
using Tensor = framework::Tensor;
3033

3134
template <typename T>
32-
HOSTDEVICE void Update(const bool* found_inf_data,
33-
const T* pre_loss_scaling_data, const int* good_in_data,
34-
const int* bad_in_data, const int incr_every_n_steps,
35-
const int decr_every_n_nan_or_inf,
36-
const float incr_ratio, const float decr_ratio,
37-
T* updated_loss_scaling_data, int* good_out_data,
38-
int* bad_out_data) {
35+
inline HOSTDEVICE bool check_finite(T value) {
36+
#if defined(PADDLE_WITH_CUDA) && defined(__NVCC__)
37+
return isfinite(value);
38+
#else
39+
return std::isfinite(value);
40+
#endif
41+
}
42+
43+
template <typename T>
44+
inline HOSTDEVICE void Update(const bool* found_inf_data,
45+
const T* pre_loss_scaling_data,
46+
const int* good_in_data, const int* bad_in_data,
47+
const int incr_every_n_steps,
48+
const int decr_every_n_nan_or_inf,
49+
const float incr_ratio, const float decr_ratio,
50+
T* updated_loss_scaling_data, int* good_out_data,
51+
int* bad_out_data) {
3952
if (*found_inf_data) {
4053
*good_out_data = 0;
4154
*bad_out_data = *bad_in_data + 1;
@@ -51,7 +64,7 @@ HOSTDEVICE void Update(const bool* found_inf_data,
5164
*good_out_data = *good_in_data + 1;
5265
if (*good_out_data == incr_every_n_steps) {
5366
T new_loss_scaling = *pre_loss_scaling_data * incr_ratio;
54-
*updated_loss_scaling_data = std::isfinite(new_loss_scaling)
67+
*updated_loss_scaling_data = check_finite(new_loss_scaling)
5568
? new_loss_scaling
5669
: *pre_loss_scaling_data;
5770
*good_out_data = 0;

0 commit comments

Comments
 (0)