#include <stdexcept>
#include <torch/script.h>
- #include <c10/cuda/CUDAStream.h>
#include "CpuANISymmetryFunctions.h"
+ #ifdef ENABLE_CUDA
+ #include <c10/cuda/CUDAStream.h>
#include "CudaANISymmetryFunctions.h"

#define CHECK_CUDA_RESULT(result) \
    if (result != cudaSuccess) { \
        throw std::runtime_error(std::string("Encountered error ")+cudaGetErrorName(result)+" at "+__FILE__+":"+std::to_string(__LINE__)); \
    }
+ #endif

namespace NNPOps {
namespace ANISymmetryFunctions {
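
Note: the new ENABLE_CUDA guard is assumed to be defined by the build system only when a CUDA toolchain is available; the mechanism that defines it (for example a -DENABLE_CUDA compile flag) is not part of this diff. A minimal sketch of the intended effect, with only the macro name taken from the guards above and everything else illustrative:

    // Assumed CUDA build:      g++ -DENABLE_CUDA ...
    // Assumed CPU-only build:  g++ ...
    #ifdef ENABLE_CUDA
    #include <c10/cuda/CUDAStream.h>   // CUDA-only dependency, compiled out otherwise
    #endif

    void example() {
    #ifdef ENABLE_CUDA
        // CUDA-specific path: present only when the flag is defined.
    #else
        // CPU-only fallback: no CUDA headers or runtime calls are needed.
    #endif
    }

With this pattern a CPU-only build never references cudaSuccess, cudaGetErrorName, or any other CUDA runtime symbol.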
@@ -87,17 +89,23 @@ class Holder : public torch::CustomClassHolder {
        const torch::Device& device = tensorOptions.device();
        if (device.is_cpu())
            symFunc = std::make_shared<CpuANISymmetryFunctions>(numAtoms, numSpecies, Rcr, Rca, false, atomSpecies, radialFunctions, angularFunctions, true);
-        if (device.is_cuda()) {
+ #ifdef ENABLE_CUDA
+        else if (device.is_cuda()) {
            // PyTorch allows choosing the GPU with "torch.device", but it doesn't set it as the default one.
            CHECK_CUDA_RESULT(cudaSetDevice(device.index()));
            symFunc = std::make_shared<CudaANISymmetryFunctions>(numAtoms, numSpecies, Rcr, Rca, false, atomSpecies, radialFunctions, angularFunctions, true);
        }
+ #endif
+        else
+            throw std::runtime_error("Unsupported device: " + device.str());

        radial = torch::empty({numAtoms, numSpecies * (int)radialFunctions.size()}, tensorOptions);
        angular = torch::empty({numAtoms, numSpecies * (numSpecies + 1) / 2 * (int)angularFunctions.size()}, tensorOptions);
        positionsGrad = torch::empty({numAtoms, 3}, tensorOptions);

+ #ifdef ENABLE_CUDA
        cudaSymFunc = dynamic_cast<CudaANISymmetryFunctions*>(symFunc.get());
+ #endif
    };

    tensor_list forward(const Tensor& positions_, const optional<Tensor>& periodicBoxVectors_) {
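
The "torch.device" comment in the hunk above alludes to the fact that a tensor's device index does not change the CUDA runtime's current device, so native code that allocates its own CUDA resources still has to call cudaSetDevice. A minimal sketch of that pattern, assuming standard libtorch and CUDA runtime APIs (the helper function itself is hypothetical):

    #include <cuda_runtime.h>
    #include <torch/torch.h>

    // Make the CUDA runtime's current device match a tensor's device before
    // creating non-PyTorch CUDA resources (streams, buffers, kernel launches).
    void alignCudaDeviceWith(const torch::Tensor& t) {
        if (t.device().is_cuda())
            cudaSetDevice(t.device().index());   // tensor options alone do not do this
    }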
@@ -111,10 +119,12 @@ class Holder : public torch::CustomClassHolder {
            float* periodicBoxVectorsPtr = periodicBoxVectors.data_ptr<float>();
        }

+ #ifdef ENABLE_CUDA
        if (cudaSymFunc) {
            const torch::cuda::CUDAStream stream = torch::cuda::getCurrentCUDAStream(tensorOptions.device().index());
            cudaSymFunc->setStream(stream.stream());
        }
+ #endif

        symFunc->computeSymmetryFunctions(positions.data_ptr<float>(), periodicBoxVectorsPtr, radial.data_ptr<float>(), angular.data_ptr<float>());

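The guarded cudaSymFunc blocks pass PyTorch's current CUDA stream to the native implementation so its kernels stay ordered with the surrounding tensor operations on that device. A self-contained sketch of the same idea using the c10::cuda API from the header included above (the function and the commented-out kernel launch are hypothetical):

    #include <c10/cuda/CUDAStream.h>

    // Fetch the stream PyTorch is currently recording work on for a device and
    // reuse its raw handle for custom CUDA work.
    void runOnTorchStream(int deviceIndex) {
        const c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(deviceIndex);
        cudaStream_t raw = stream.stream();   // raw cudaStream_t for custom kernels
        // myKernel<<<blocks, threads, 0, raw>>>(...);
        (void)raw;
    }
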
@@ -126,10 +136,12 @@ class Holder : public torch::CustomClassHolder {
        const Tensor radialGrad = grads[0].clone();
        const Tensor angularGrad = grads[1].clone();

+ #ifdef ENABLE_CUDA
        if (cudaSymFunc) {
            const torch::cuda::CUDAStream stream = torch::cuda::getCurrentCUDAStream(tensorOptions.device().index());
            cudaSymFunc->setStream(stream.stream());
        }
+ #endif

        symFunc->backprop(radialGrad.data_ptr<float>(), angularGrad.data_ptr<float>(), positionsGrad.data_ptr<float>());

@@ -146,7 +158,9 @@ class Holder : public torch::CustomClassHolder {
    Tensor radial;
    Tensor angular;
    Tensor positionsGrad;
+ #ifdef ENABLE_CUDA
    CudaANISymmetryFunctions* cudaSymFunc;
+ #endif
};

class AutogradFunctions : public torch::autograd::Function<AutogradFunctions> {