Fixing Tensor.backward's function signature #1376
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Fixes #692
TLDR:
Tensor.backwardhas a different parameter order compared to PyTorch and also swapsretain_graphandcreate_graphin its internal function call.See https://pytorch.org/docs/stable/generated/torch.Tensor.backward.html for backward's function signature:
Tensor.backward(gradient=None, retain_graph=None, create_graph=False, inputs=None)The current TorchSharp version's function signature is:
Tensor.backward(grad_tensors=null, create_graph=false, retain_graph=false, inputs=null)Note the difference between the ordering of
retain_graphandcreate_graph.Tensor.backwardis just a wrapper totorch.autograd.backwardwhich has a function signature of:autograd.backward(tensors, grad_tensors=null, retain_graph=null, create_graph=false, inputs=null)This means calling
Tensor.backward(retain_graph: true)in TorchSharp is actuallyTensor.backward(create_graph:true)in PyTorch. Same thing forTensor.backward(create_graph: true)actually beingTensor.backward(retain_graph:true).The proposed fix is breaking and would change the
Tensor.backwardfunction signature to match PyTorch. However, nobody noticed for like 2 years anyway and imoretain_graphshould actually meanretain_graph(and same forcreate_graph) 🙂.