Skip to content

Commit 489a64e

Browse files
authored
fix cross_entropy when running in static graph mode on MLU and NPU (PaddlePaddle#40621)
1 parent cb8afc2 commit 489a64e

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

python/paddle/nn/functional/loss.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1818,12 +1818,16 @@ def cross_entropy(input,
18181818
helper = LayerHelper('softmax_with_cross_entropy', **locals())
18191819
softmax = helper.create_variable_for_type_inference(dtype=input.dtype)
18201820
out = helper.create_variable_for_type_inference(dtype=input.dtype)
1821+
1822+
outputs = {'Softmax': softmax, 'Loss': out}
1823+
if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
1824+
backprop = helper.create_variable_for_type_inference(dtype=input.dtype)
1825+
outputs['Backprop'] = backprop
18211826
helper.append_op(
18221827
type='softmax_with_cross_entropy',
18231828
inputs={'Logits': input,
18241829
'Label': label},
1825-
outputs={'Softmax': softmax,
1826-
'Loss': out},
1830+
outputs=outputs,
18271831
attrs=attrs)
18281832

18291833
if weight is not None:

0 commit comments

Comments
 (0)