
Commit 2cf4707

update
1 parent 6a637a3 commit 2cf4707

File tree: 2 files changed (+13 −6 lines)


python/paddle/distributed/fleet/base/fleet_base.py

Lines changed: 5 additions & 4 deletions
@@ -1586,15 +1586,16 @@ def unscale_method(self, optimizer):
             _C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
                                             param_grads_fp32,
                                             temp_found_inf_fp32)
-        self._found_inf = temp_found_inf_fp16 or temp_found_inf_fp32
+        self._found_inf = 1 if temp_found_inf_fp16 or temp_found_inf_fp32 else 0

         # TODO(shenliang03) Since dp allreduce in the optimizer is
         # after the gradscaler, check_finite needs to synchronize global
         # information. In the future, we should use check_group to speed.
-        self._found_inf = paddle.cast(self._found_inf, dtype="int32")
         paddle.distributed.all_reduce(
-            self._found_inf, op=paddle.distributed.ReduceOp.MAX, group=None)
-        self._found_inf = paddle.cast(self._found_inf, dtype="bool")
+            paddle.to_tensor(
+                [self._found_inf], dtype="int32"),
+            op=paddle.distributed.ReduceOp.MAX,
+            group=None)

         # Only tensor_parallel and pipeline_parallel need to modify scaler
         if self._hcg.get_parallel_mode() in (ParallelMode.TENSOR_PARALLEL,
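For context, the new code publishes the per-rank found_inf flag as an int32 tensor and all-reduces it with ReduceOp.MAX, which acts as a logical OR across data-parallel workers. Below is a minimal sketch of that pattern (not part of the commit), assuming paddle.distributed.init_parallel_env() has already been called, e.g. under python -m paddle.distributed.launch; the helper name sync_found_inf is hypothetical:

import paddle
import paddle.distributed as dist

def sync_found_inf(local_found_inf):
    # Encode the local bool as 0/1 so that ReduceOp.MAX behaves as a
    # logical OR across ranks: any rank that saw inf/nan drives the
    # reduced value to 1.
    flag = paddle.to_tensor([1 if local_found_inf else 0], dtype="int32")
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return bool(flag.numpy()[0])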

python/paddle/distributed/fleet/meta_parallel/pp_utils/utils.py

Lines changed: 8 additions & 2 deletions
@@ -198,8 +198,14 @@ def forward(ctx, run_function, all_outputs, *args):

         # TODO support AMP
         tracer = framework._dygraph_tracer()
-        ctx.is_fw_autocast = False if tracer._amp_level == 0 else True
-        ctx.amp_level = 'O2' if tracer._amp_level == 2 else 'O1'
+        ctx.is_fw_autocast = False if tracer._amp_level == core.AmpLevel.O0 else True
+        if tracer._amp_level == core.AmpLevel.O2:
+            ctx.amp_level = 'O2'
+        elif tracer._amp_level == core.AmpLevel.O1:
+            ctx.amp_level = 'O1'
+        else:
+            raise ValueError("unsupported amp level: {}".format(
+                tracer._amp_level))
         ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()

         with paddle.no_grad():
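For context, this change replaces the raw integer comparison of tracer._amp_level with the core.AmpLevel enum and rejects any level other than O0/O1/O2. A minimal standalone sketch of that mapping (not part of the commit), assuming core is importable as paddle.fluid.core as in the Paddle 2.x tree this file belongs to; the helper name amp_level_to_str is hypothetical:

from paddle.fluid import core

def amp_level_to_str(amp_level):
    # Translate the tracer's AmpLevel enum into the 'O1'/'O2' strings
    # used elsewhere in the AMP code; O0 means autocast is disabled.
    if amp_level == core.AmpLevel.O2:
        return 'O2'
    elif amp_level == core.AmpLevel.O1:
        return 'O1'
    elif amp_level == core.AmpLevel.O0:
        return None
    raise ValueError("unsupported amp level: {}".format(amp_level))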
