diff --git a/model.cpp b/model.cpp index ae02d2b..bc391ae 100644 --- a/model.cpp +++ b/model.cpp @@ -6,6 +6,12 @@ #include "tensor_math.hpp" #include "gsplat.hpp" +#ifdef USE_HIP +#include +#elif defined(USE_CUDA) +#include +#endif + torch::Tensor randomQuatTensor(long long n){ torch::Tensor u = torch::rand(n); torch::Tensor v = torch::rand(n); @@ -442,6 +448,11 @@ void Model::afterTrain(int step){ xysGradNorm = torch::Tensor(); visCounts = torch::Tensor(); max2DSize = torch::Tensor(); +#ifdef USE_HIP + c10::hip::HIPCachingAllocator::emptyCache(); +#elif defined(USE_CUDA) + c10::cuda::CUDACachingAllocator::emptyCache(); +#endif } }