 ]


-def nll_loss(x, label, weight=None, ignore_index=-100, reduction='mean'):
+def nll_loss(x,
+             label,
+             weight=None,
+             ignore_index=-100,
+             reduction='mean',
+             name=None):
     """
     This api returns negative log likelihood.
-    See more detail in `paddle.nn.loss.NLLLoss` .
-
+    See more detail in :ref:`api_nn_loss_NLLLoss` .
+
     Parameters:
         x (Tensor): Input tensor, the data type is float32, float64.
-        label (Tensor): Label tensor, the data type is int64_t.
+        label (Tensor): Label tensor, the data type is int64.
         weight (Tensor, optional): Weight tensor, a manual rescaling weight given
             to each class. If given, it has to be a Tensor of size `C`. Otherwise,
             it treated as if having all ones. the data type is
@@ -85,48 +90,35 @@ def nll_loss(x, label, weight=None, ignore_index=-100, reduction='mean'):
         reduction (str, optional): Indicate how to average the loss,
             the candicates are ``'none'`` | ``'mean'`` | ``'sum'``.
             If :attr:`reduction` is ``'mean'``, the reduced mean loss is returned;
+            if :attr:`reduction` is ``'sum'``, the reduced sum loss is returned;
+            if :attr:`reduction` is ``'none'``, no reduction will be applied.
             Default is ``'mean'``.
-
+        name (str, optional): Name for the operation (optional, default is None).
+            For more information, please refer to :ref:`api_guide_Name`.
+
     Returns:
         The tensor variable storing the nll_loss.
-
+
     Examples:
-        import paddle
-        import numpy as np
-        from paddle.nn.functional import nll_loss
-        log_softmax = paddle.nn.LogSoftmax(axis=1)
-
-        x_np = np.random.random(size=(10, 10)).astype(np.float32)
-        label_np = np.random.randint(0, 10, size=(10,)).astype(np.int64)
-
-        place = paddle.CPUPlace()
-
-        # imperative mode
-        paddle.enable_imperative(place)
-        x = paddle.imperative.to_variable(x_np)
-        log_out = log_softmax(x)
-        label = paddle.imperative.to_variable(label_np)
-        imperative_result = nll_loss(log_out, label)
-        print(imperative_result.numpy())
-
-        # declarative mode
-        paddle.disable_imperative()
-        prog = paddle.Program()
-        startup_prog = paddle.Program()
-        with paddle.program_guard(prog, startup_prog):
-            x = paddle.nn.data(name='x', shape=[10, 10], dtype='float32')
-            label = paddle.nn.data(name='label', shape=[10], dtype='int64')
-            log_out = log_softmax(x)
-            res = nll_loss(log_out, label)
-
-            exe = paddle.Executor(place)
-            declaritive_result = exe.run(
-                prog,
-                feed={"x": x_np,
-                      "label": label_np},
-                fetch_list=[res])
-            print(declaritive_result)
+        .. code-block:: python
+
+            import paddle
+            import numpy as np
+            from paddle.nn.functional import nll_loss
+            log_softmax = paddle.nn.LogSoftmax(axis=1)

+            x_np = np.random.random(size=(10, 10)).astype(np.float32)
+            label_np = np.random.randint(0, 10, size=(10,)).astype(np.int64)
+
+            place = paddle.CPUPlace()
+
+            # imperative mode
+            paddle.enable_imperative(place)
+            x = paddle.imperative.to_variable(x_np)
+            log_out = log_softmax(x)
+            label = paddle.imperative.to_variable(label_np)
+            imperative_result = nll_loss(log_out, label)
+            print(imperative_result.numpy())
     """
     if reduction not in ['sum', 'mean', 'none']:
         raise ValueError(
@@ -152,6 +144,7 @@ def nll_loss(x, label, weight=None, ignore_index=-100, reduction='mean'):
         out, _ = core.ops.reshape2(out, 'shape', out_shape)
         return out

+    helper = LayerHelper('nll_loss', **locals())
     x_shape = list(x.shape)
     x_dims = len(x_shape)
     if x_dims < 2:
@@ -165,7 +158,6 @@ def nll_loss(x, label, weight=None, ignore_index=-100, reduction='mean'):
         label = paddle.reshape(label, shape=[n, 1, -1])
         out_shape = [n] + x_shape[2:]

-    helper = LayerHelper('nll_loss', **locals())
     fluid.data_feeder.check_variable_and_dtype(x, 'x', ['float32', 'float64'],
                                                'nll_loss')
     fluid.data_feeder.check_variable_and_dtype(label, 'label', ['int64'],
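
Since the hunks above only show fragments of the implementation, here is a minimal NumPy sketch of the computation the updated docstring describes for the 2-D case (an N x C tensor of log-probabilities plus N integer labels). The function name `nll_loss_reference` and the loop-based handling of `weight` and `ignore_index` are illustrative assumptions for readability, not Paddle's actual kernel:

```python
import numpy as np

def nll_loss_reference(log_probs, labels, weight=None,
                       ignore_index=-100, reduction='mean'):
    """Illustrative re-statement of the docstring semantics (not Paddle code)."""
    n, c = log_probs.shape
    w = np.ones(c) if weight is None else np.asarray(weight)
    losses = np.zeros(n)
    sample_w = np.zeros(n)
    for i, lbl in enumerate(labels):
        if lbl == ignore_index:
            continue                                # ignored targets contribute nothing
        sample_w[i] = w[lbl]
        losses[i] = -log_probs[i, lbl] * w[lbl]     # pick log-prob of the true class
    if reduction == 'none':
        return losses                               # one loss value per sample
    if reduction == 'sum':
        return losses.sum()
    return losses.sum() / sample_w.sum()            # 'mean': weighted mean over kept samples

# Uniform predictions over 4 classes -> every per-sample loss is -log(0.25).
log_probs = np.log(np.full((3, 4), 0.25))
labels = np.array([0, 2, 3])
print(nll_loss_reference(log_probs, labels))                    # ~1.3863
print(nll_loss_reference(log_probs, labels, reduction='none'))  # [1.3863 1.3863 1.3863]
```

The weighted-mean behaviour for ``'mean'`` follows the usual NLLLoss convention and is an assumption here; the docstring above only names the three reduction modes.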
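Assuming the same imperative-mode API that the new docstring example relies on (`paddle.enable_imperative`, `paddle.imperative.to_variable`), the added `reduction` and `name` arguments could be exercised as sketched below; the operation name `'my_nll_loss'` and the printed shape are placeholders for illustration:

```python
import numpy as np
import paddle
from paddle.nn.functional import nll_loss

paddle.enable_imperative(paddle.CPUPlace())
log_softmax = paddle.nn.LogSoftmax(axis=1)

x = paddle.imperative.to_variable(
    np.random.random(size=(10, 10)).astype(np.float32))
label = paddle.imperative.to_variable(
    np.random.randint(0, 10, size=(10,)).astype(np.int64))

# reduction='none' keeps one loss value per sample instead of averaging;
# name labels the operation, as described in the new docstring entry.
per_sample = nll_loss(log_softmax(x), label, reduction='none', name='my_nll_loss')
print(per_sample.numpy().shape)  # expected: (10,)
```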