@@ -35,7 +35,7 @@ limitations under the License.
3535// . [DONE] spatial gradient mode (without multiplication with output gradient)
3636// . [DONE] second order gradients (backward pass for spatial gradients)
3737// . performance tests
38- // . input bound/inter are always vectors -> clean unused constructors
38+ // . [DONE] input bound/inter are always vectors -> clean unused constructors
3939
4040#include < ATen/ATen.h>
4141#include < limits>
@@ -2149,39 +2149,42 @@ MONAI_NAMESPACE_DEVICE { // cuda
21492149 // FUNCTIONAL FORM WITH DISPATCH
21502150 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
21512151
2152- #define PUSHPULL_INSTANTIATE3 (BoundType0, InterpolationType0, SourceType0 ) \
2153- template std::deque<Tensor> pushpull ( \
2154- const SourceType0&, \
2155- const Tensor&, \
2156- const Tensor&, \
2157- BoundType0, \
2158- InterpolationType0, \
2159- bool , \
2160- bool , \
2161- bool , \
2162- bool , \
2163- bool , \
2164- bool ); \
2165- template std::deque<Tensor> pushpull ( \
2166- const SourceType0&, const Tensor&, BoundType0, InterpolationType0, bool , bool , bool , bool , bool , bool )
2167- #define PUSHPULL_INSTANTIATE2 (BoundType0, InterpolationType0 ) \
2168- PUSHPULL_INSTANTIATE3 (BoundType0, InterpolationType0, IntArrayRef); \
2169- PUSHPULL_INSTANTIATE3 (BoundType0, InterpolationType0, Tensor)
2170- #define PUSHPULL_INSTANTIATE1 (BoundType0 ) \
2171- PUSHPULL_INSTANTIATE2 (BoundType0, InterpolationType); \
2172- PUSHPULL_INSTANTIATE2 (BoundType0, InterpolationVectorRef)
2173- #define PUSHPULL_INSTANTIATE \
2174- PUSHPULL_INSTANTIATE1 (BoundType); \
2175- PUSHPULL_INSTANTIATE1 (BoundVectorRef)
2152+ #define PUSHPULL_INSTANTIATE_SOURCE (SourceType ) \
2153+ template std::deque<Tensor> pushpull ( \
2154+ const SourceType&, \
2155+ const Tensor&, \
2156+ const Tensor&, \
2157+ BoundVectorRef, \
2158+ InterpolationVectorRef, \
2159+ bool , \
2160+ bool , \
2161+ bool , \
2162+ bool , \
2163+ bool ); \
2164+ template std::deque<Tensor> pushpull ( \
2165+ const SourceType&, \
2166+ const Tensor&, \
2167+ BoundVectorRef, \
2168+ InterpolationVectorRef, \
2169+ bool , \
2170+ bool , \
2171+ bool , \
2172+ bool , \
2173+ bool , \
2174+ bool )
2175+
2176+ #define PUSHPULL_INSTANTIATE \
2177+ PUSHPULL_INSTANTIATE_SOURCE (IntArrayRef); \
2178+ PUSHPULL_INSTANTIATE_SOURCE (Tensor)
21762179
21772180 // Two arguments (source, grid)
2178- // > ` bound` and ` interpolation` can be single arguments or vectors .
2179- template <typename BoundType, typename InterpolationType, typename SourceType>
2181+ // > bound and interpolation are strictly VectorRef .
2182+ template <typename SourceType>
21802183 MONAI_HOST std::deque<Tensor> pushpull(
21812184 const SourceType& source,
21822185 const Tensor& grid,
2183- BoundType bound,
2184- InterpolationType interpolation,
2186+ BoundVectorRef bound,
2187+ InterpolationVectorRef interpolation,
21852188 bool extrapolate,
21862189 bool do_pull,
21872190 bool do_push,
@@ -2206,15 +2209,14 @@ MONAI_NAMESPACE_DEVICE { // cuda
22062209 }
22072210
22082211 // Three arguments (source, grid, target)
2209- // > `bound` and `interpolation` can be single arguments or vectors.
2210- // > `source` can be a tensor or a vector of dimensions.
2211- template <typename BoundType, typename InterpolationType, typename SourceType>
2212+ // > bound and interpolation are strictly VectorRef.
2213+ template <typename SourceType>
22122214 MONAI_HOST std::deque<Tensor> pushpull (
22132215 const SourceType& source,
22142216 const Tensor& grid,
22152217 const Tensor& target,
2216- BoundType bound,
2217- InterpolationType interpolation,
2218+ BoundVectorRef bound,
2219+ InterpolationVectorRef interpolation,
22182220 bool extrapolate,
22192221 bool do_pull,
22202222 bool do_push,
0 commit comments