Skip to content

Commit 062134d

Browse files
committed
Save state also for backward pass
1 parent 770c3ea commit 062134d

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/pyronn_torch/conebeam.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,16 @@ def backward(self, projection_grad, state=None, *args):
8585
state.projection_multiplier, volume_grad, *state.volume_origin,
8686
*state.volume_spacing)
8787

88+
self.state = state
8889
if return_none:
8990
return volume_grad, None
9091
else:
9192
return volume_grad,
9293

9394

9495
class _BackwardProjection(torch.autograd.Function):
95-
backward = _ForwardProjection.forward
96-
forward = _ForwardProjection.backward
96+
backward = staticmethod(_ForwardProjection.forward)
97+
forward = staticmethod(_ForwardProjection.backward)
9798

9899

99100
class ConeBeamProjector:

0 commit comments

Comments
 (0)