@@ -218,20 +218,20 @@ TORCH_LIBRARY_IMPL(torchmdnet_extensions, CompositeImplicitAutograd, m) {
218
218
});
219
219
}
220
220
221
- // // Explicit device backend registrations for PyTorch versions that do not
222
- // // automatically fall back to CompositeImplicitAutograd for device dispatch.
223
- // TORCH_LIBRARY_IMPL(torchmdnet_extensions, CPU, m) {
224
- // m.impl("get_neighbor_pairs",
225
- // [](const std::string& strategy, const Tensor& positions, const Tensor& batch,
226
- // const Tensor& box_vectors, bool use_periodic, const Scalar& cutoff_lower,
227
- // const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop,
228
- // bool include_transpose) {
229
- // auto result = NeighborAutograd::apply(strategy, positions, batch, box_vectors,
230
- // use_periodic, cutoff_lower, cutoff_upper,
231
- // max_num_pairs, loop, include_transpose);
232
- // return std::make_tuple(result[0], result[1], result[2], result[3]);
233
- // });
234
- // }
221
+ // Explicit device backend registrations for PyTorch versions that do not
222
+ // automatically fall back to CompositeImplicitAutograd for device dispatch.
223
+ TORCH_LIBRARY_IMPL (torchmdnet_extensions, CPU, m) {
224
+ m.impl (" get_neighbor_pairs" ,
225
+ [](const std::string& strategy, const Tensor& positions, const Tensor& batch,
226
+ const Tensor& box_vectors, bool use_periodic, const Scalar& cutoff_lower,
227
+ const Scalar& cutoff_upper, const Scalar& max_num_pairs, bool loop,
228
+ bool include_transpose) {
229
+ auto result = NeighborAutograd::apply (strategy, positions, batch, box_vectors,
230
+ use_periodic, cutoff_lower, cutoff_upper,
231
+ max_num_pairs, loop, include_transpose);
232
+ return std::make_tuple (result[0 ], result[1 ], result[2 ], result[3 ]);
233
+ });
234
+ }
235
235
236
236
// TORCH_LIBRARY_IMPL(torchmdnet_extensions, CUDA, m) {
237
237
// m.impl("get_neighbor_pairs",
0 commit comments