11#pragma once
22#include < cstddef>
33#include < stdexcept>
4+ #include < string>
45
56#include < ttl/bits/std_cuda_runtime.hpp>
67#include < ttl/bits/std_device.hpp>
@@ -10,6 +11,23 @@ namespace ttl
1011{
1112namespace internal
1213{
14+ class std_cuda_error_checker_t
15+ {
16+ const std::string func_name_;
17+
18+ public:
19+ std_cuda_error_checker_t (const char *func_name) : func_name_(func_name) {}
20+
21+ void operator <<(const cudaError_t err) const
22+ {
23+ if (err != cudaSuccess) {
24+ throw std::runtime_error (func_name_ + " failed with: " +
25+ std::to_string (static_cast <int >(err)) +
26+ " : " + cudaGetErrorString (err));
27+ }
28+ }
29+ }; // namespace ttl
30+
1331struct cuda_copier {
1432 static constexpr auto h2d = cudaMemcpyHostToDevice;
1533 static constexpr auto d2h = cudaMemcpyDeviceToHost;
@@ -18,10 +36,8 @@ struct cuda_copier {
1836 template <cudaMemcpyKind dir>
1937 static void copy (void *dst, const void *src, size_t size)
2038 {
21- const cudaError_t err = cudaMemcpy (dst, src, size, dir);
22- if (err != cudaSuccess) {
23- throw std::runtime_error (" cudaMemcpy failed" );
24- }
39+ static std_cuda_error_checker_t check (" cudaMemcpy" );
40+ check << cudaMemcpy (dst, src, size, dir);
2541 }
2642};
2743
@@ -54,10 +70,8 @@ class basic_allocator<R, cuda_memory>
5470 void *deviceMem;
5571 // cudaMalloc<R>(&deviceMem, count);
5672 // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY.html#group__CUDART__MEMORY
57- const cudaError_t err = cudaMalloc (&deviceMem, count * sizeof (R));
58- if (err != cudaSuccess) {
59- throw std::runtime_error (" cudaMalloc failed" );
60- }
73+ static std_cuda_error_checker_t check (" cudaMalloc" );
74+ check << cudaMalloc (&deviceMem, count * sizeof (R));
6175 return reinterpret_cast <R *>(deviceMem);
6276 }
6377};
@@ -68,8 +82,8 @@ class basic_deallocator<R, cuda_memory>
6882 public:
6983 void operator ()(R *data)
7084 {
71- const cudaError_t err = cudaFree (data );
72- if (err != cudaSuccess) { throw std::runtime_error ( " cudaFree failed " ); }
85+ static std_cuda_error_checker_t check ( " cudaFree " );
86+ check << cudaFree (data);
7387 }
7488};
7589} // namespace internal
0 commit comments