|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | from paddle.fluid import framework
|
| 16 | +from .utils import _check_tensors, _stack_tensor_or_return_none, _replace_none_with_zero_tensor |
16 | 17 | import paddle
|
17 | 18 |
|
18 | 19 |
|
19 |
| -def _check_tensors(in_out_list, name): |
20 |
| - assert in_out_list is not None, "{} should not be None".format(name) |
21 |
| - |
22 |
| - if isinstance(in_out_list, (list, tuple)): |
23 |
| - assert len(in_out_list) > 0, "{} connot be empyt".format(name) |
24 |
| - for each_var in in_out_list: |
25 |
| - assert isinstance( |
26 |
| - each_var, |
27 |
| - paddle.Tensor), "Elements of {} must be paddle.Tensor".format( |
28 |
| - name) |
29 |
| - return in_out_list |
30 |
| - else: |
31 |
| - assert isinstance( |
32 |
| - in_out_list, |
33 |
| - paddle.Tensor), "{} must be Tensor or list of Tensor".format(name) |
34 |
| - return [in_out_list] |
35 |
| - |
36 |
| - |
37 |
| -def _stack_tensor_or_return_none(origin_list): |
38 |
| - assert len(origin_list) > 0, "Can't not stack an empty list" |
39 |
| - return paddle.stack( |
40 |
| - origin_list, axis=0) if isinstance(origin_list[0], |
41 |
| - paddle.Tensor) else None |
42 |
| - |
43 |
| - |
44 | 20 | @framework.dygraph_only
|
45 | 21 | def jacobian(func, inputs, create_graph=False, allow_unused=False):
|
46 | 22 | '''
|
@@ -183,3 +159,129 @@ def func(x, y):
|
183 | 159 | return jacobian[0]
|
184 | 160 | else:
|
185 | 161 | return jacobian
|
| 162 | + |
| 163 | + |
@framework.dygraph_only
def hessian(func, inputs, create_graph=False, allow_unused=False):
    '''
    .. note::
        **This API is ONLY available in imperative mode.**

    This API computes the Hessian matrix of `func` with respect to `inputs`.

    Parameters:
        func (function): a Python function that takes a Tensor or a Tensor
            list/tuple as inputs and returns a Tensor with a single element.
        inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
            Tensor list/tuple of the function ``func``.
        create_graph (bool, optional): whether to create the gradient graphs
            of the computing process. When it is True, higher order derivatives
            are supported to compute; when it is False, the gradient graphs of
            the computing process would be discarded. Defaults to ``False``.
        allow_unused (bool, optional): whether to raise error or return None if
            some Tensors of `inputs` are unreachable in the graph. Error would
            be raised if allow_unused=False, and None would be returned as
            their gradients if allow_unused=True. Default False.
    Returns:
        Hessian (Tensor or a tuple of tuple of Tensors): if function ``func``
        takes a Tensor as ``inputs``, Hessian will be a single Tensor containing
        the Hessian matrix for the linearized ``inputs`` Tensor. If function
        ``func`` takes a Tensor list/tuple as ``inputs``, then the Hessian will
        be a tuple of tuple of Tensors where ``Hessian[i][j]`` will contain the
        Hessian matrix of the ``i``th input and ``j``th input with size ``m * n``.
        Here ``m`` and ``n`` denote the number of elements of the ``i`` th input
        and the ``j`` th input respectively.

    Examples 1:
        .. code-block:: python

            import paddle

            def func(x):
                return paddle.sum(paddle.matmul(x, x))

            x = paddle.ones(shape=[2, 2], dtype='float32')
            x.stop_gradient = False
            hessian = paddle.autograd.hessian(func, x)
            print(hessian)
            # Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #        [[2., 1., 1., 0.],
            #         [1., 0., 2., 1.],
            #         [1., 2., 0., 1.],
            #         [0., 1., 1., 2.]])

    Examples 2:
        .. code-block:: python

            import paddle

            def func(x, y):
                return paddle.sum(paddle.matmul(x, y))

            x = paddle.ones(shape=[2, 2], dtype='float32')
            y = paddle.ones(shape=[2, 2], dtype='float32')
            x.stop_gradient = False
            y.stop_gradient = False
            hessian = paddle.autograd.hessian(func, [x, y])
            print(hessian)
            # ((Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #          [[0., 0., 0., 0.],
            #           [0., 0., 0., 0.],
            #           [0., 0., 0., 0.],
            #           [0., 0., 0., 0.]]),
            #   Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #          [[1., 1., 0., 0.],
            #           [0., 0., 1., 1.],
            #           [1., 1., 0., 0.],
            #           [0., 0., 1., 1.]])),
            #  (Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #          [[1., 0., 1., 0.],
            #           [1., 0., 1., 0.],
            #           [0., 1., 0., 1.],
            #           [0., 1., 0., 1.]]),
            #   Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #          [[0., 0., 0., 0.],
            #           [0., 0., 0., 0.],
            #           [0., 0., 0., 0.],
            #           [0., 0., 0., 0.]])))

    Examples 3:
        .. code-block:: python

            import paddle

            def func(x, y):
                return paddle.sum(paddle.matmul(x, x))

            x = paddle.ones(shape=[2, 2], dtype='float32')
            y = paddle.ones(shape=[2, 2], dtype='float32')
            x.stop_gradient = False
            y.stop_gradient = False
            hessian = paddle.autograd.hessian(func, [x, y], allow_unused=True)
            print(hessian)
            # ((Tensor(shape=[4, 4], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #          [[2., 1., 1., 0.],
            #           [1., 0., 2., 1.],
            #           [1., 2., 0., 1.],
            #           [0., 1., 1., 2.]]), None), (None, None)) 

    '''
    # Normalize ``inputs`` into a list of Tensors, then evaluate ``func`` once;
    # this single forward result is differentiated twice below.
    inputs = _check_tensors(inputs, "inputs")
    outputs = func(*inputs)
    assert isinstance(outputs, paddle.Tensor) and outputs.shape == [
        1
    ], "The function to compute Hessian matrix should return a Tensor with a single element"

    def first_order_grad(*wrt):
        # First-order gradients of ``outputs`` w.r.t. ``wrt``. The graph is
        # kept alive (create_graph/retain_graph both True) so that jacobian()
        # can differentiate through these gradients a second time.
        partials = paddle.grad(
            outputs,
            wrt,
            create_graph=True,
            retain_graph=True,
            allow_unused=allow_unused)
        # Unreachable inputs come back as None; replace each with a zero
        # tensor shaped like the corresponding input so later stacking works.
        return tuple(
            _replace_none_with_zero_tensor(g, ref)
            for g, ref in zip(partials, inputs))

    # The Hessian is exactly the Jacobian of the first-order gradient map.
    return jacobian(
        first_order_grad,
        inputs,
        create_graph=create_graph,
        allow_unused=allow_unused)
0 commit comments