Skip to content

Commit 63acc7b

Browse files
committed
fix some cuda memory access error.
1 parent 996a53c commit 63acc7b

File tree

3 files changed

+9
-10
lines changed

3 files changed

+9
-10
lines changed

paddle/fluid/operators/amp/update_loss_scaling_op.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,12 @@ class LazyZeros<platform::CPUDeviceContext, T> {
141141
const bool* found_inf_data,
142142
const std::vector<const framework::Tensor*>& xs,
143143
const std::vector<framework::Tensor*>& outs) const {
144-
if (*found_inf_data) {
145-
VLOG(1) << "-- UpdateLossScaling: Infinite values are found in grads. --";
146-
for (size_t i = 0; i < xs.size(); ++i) {
147-
auto* out = outs[i];
148-
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
149-
int num = out->numel();
144+
for (size_t i = 0; i < xs.size(); ++i) {
145+
auto* out = outs[i];
146+
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
147+
int num = out->numel();
148+
if (*found_inf_data) {
149+
VLOG(1) << "-- UpdateLossScaling: Find infinite grads. --";
150150
std::memset(out_data, 0, num * sizeof(T));
151151
}
152152
}

paddle/fluid/operators/amp/update_loss_scaling_op.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include <thrust/fill.h>
1615
#include <vector>
1716
#include "paddle/fluid/framework/op_registry.h"
1817
#include "paddle/fluid/operators/amp/update_loss_scaling_op.h"
@@ -34,7 +33,7 @@ __global__ void GpuUpdateLossScaling(
3433
}
3534

3635
template <typename T>
37-
__global__ void FillIf(T* data, const int num, const T& value,
36+
__global__ void FillIf(T* data, const int64_t num, const T value,
3837
const bool* has_inf) {
3938
if (*has_inf) {
4039
int tid = threadIdx.x + blockIdx.x * blockDim.x;
@@ -71,7 +70,7 @@ class LazyZeros<platform::CUDADeviceContext, T> {
7170
for (size_t i = 0; i < xs.size(); ++i) {
7271
auto* out = outs[i];
7372
T* out_data = out->mutable_data<T>(dev_ctx.GetPlace());
74-
int num = out->numel();
73+
int64_t num = out->numel();
7574
int block = 1024;
7675
int grid = (block - 1 + num) / block;
7776
FillIf<<<grid, block, 0, dev_ctx.stream()>>>(

python/paddle/fluid/tests/unittests/test_update_loss_scaling_op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def setUp(self):
3535
}
3636

3737
self.outputs = {
38-
'Out': [('out0', np.zeros_like(x))],
38+
'Out': [('out0', x)],
3939
'LossScaling': self.prev_loss_scaling * self.incr_ratio,
4040
'OutGoodSteps': self.zero_steps,
4141
'OutBadSteps': self.zero_steps

0 commit comments

Comments
 (0)