Skip to content

Commit bbd200c

Browse files
committed
Fix gradient calculation in ParallelProjector
1 parent 6a3dee3 commit bbd200c

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

src/pyronn_torch/parallel.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,22 +115,18 @@ def project_forward(self, volume):
115115
elif volume.shape[1] != 1:
116116
raise ValueError('Only channel dimension of 1 is currently supported!')
117117

118-
projs = torch.zeros(volume.shape[0],
119-
self._projection_shape[0],
120-
self._projection_shape[1], device='cuda',
121-
requires_grad=volume.requires_grad)
122-
118+
projs = []
123119
for i, slice in enumerate(volume):
124-
projs[i] = _ForwardProjection().apply(slice[0], State(
120+
projs.append(_ForwardProjection().apply(slice[0], State(
125121
self._detector_origin,
126122
self._detector_spacing,
127123
self._projection_shape,
128124
self._ray_vectors,
129125
self._volume_origin,
130126
self._volume_shape,
131127
self._volume_spacing
132-
))
133-
return projs
128+
)))
129+
return torch.stack(projs, axis=0)
134130

135131
def project_backward(self, projection):
136132
projection = projection.float().contiguous().cuda()

0 commit comments

Comments
 (0)