diff --git a/src/cuda/fusion/node2vec.cu b/src/cuda/fusion/node2vec.cu index 174d1ee..ec090e3 100644 --- a/src/cuda/fusion/node2vec.cu +++ b/src/cuda/fusion/node2vec.cu @@ -37,6 +37,13 @@ __device__ bool BinarySearch(int64_t* ptr, int64_t degree, int64_t target) { __device__ bool CheckConnect(int64_t* graph_indice, int64_t* graph_indptr, int64_t degree, int64_t src, int64_t dst) { + int64_t item = cub::UpperBound(graph_indice + graph_indptr[src], degree, dst); + if (item == degree) { + return false; + } else { + return true; + } + /* if (BinarySearch(graph_indice + graph_indptr[src], degree, dst)) { // paster() // printf("Connect %d %d \n", src, dst); @@ -44,6 +51,7 @@ __device__ bool CheckConnect(int64_t* graph_indice, int64_t* graph_indptr, } // printf("not Connect %d %d \n", src, dst); return false; + */ } __global__ void _Node2VecKernel(const int64_t* seed_data, diff --git a/src/cuda/sort_indices.cu b/src/cuda/sort_indices.cu index 459929d..e2894e6 100644 --- a/src/cuda/sort_indices.cu +++ b/src/cuda/sort_indices.cu @@ -14,11 +14,14 @@ torch::Tensor SortIndicesCUDA(torch::Tensor indptr, torch::Tensor indices) { torch::Tensor sorted_indices = torch::empty_like(indices); + bool use_uva = false; if (sorted_indices.device().type() != torch::kCUDA && !sorted_indices.is_pinned()) { sorted_indices = sorted_indices.pin_memory(); } + if (sorted_indices.is_pinned()) use_uva = true; + void* d_temp_storage = NULL; size_t temp_storage_bytes = 0; @@ -27,14 +30,28 @@ torch::Tensor SortIndicesCUDA(torch::Tensor indptr, torch::Tensor indices) { sorted_indices.data_ptr(), num_items, num_segments, indptr.data_ptr(), indptr.data_ptr() + 1); - CUDA_CALL(cudaMallocManaged(&d_temp_storage, temp_storage_bytes)); + if (use_uva) { + d_temp_storage = malloc(temp_storage_bytes); + CUDA_CALL(cudaHostRegister(d_temp_storage, temp_storage_bytes, + cudaHostRegisterDefault)); + // CUDA_CALL(cudaHostAlloc(&d_temp_storage, temp_storage_bytes, + // cudaHostAllocDefault)); + } else { + CUDA_CALL(cudaMallocManaged(&d_temp_storage, temp_storage_bytes)); + } cub::DeviceSegmentedRadixSort::SortKeys( d_temp_storage, temp_storage_bytes, indices.data_ptr(), sorted_indices.data_ptr(), num_items, num_segments, indptr.data_ptr(), indptr.data_ptr() + 1); - CUDA_CALL(cudaFree(d_temp_storage)); + if (use_uva) { + CUDA_CALL(cudaHostUnregister(d_temp_storage)); + free(d_temp_storage); + // CUDA_CALL(cudaFreeHost(d_temp_storage)); + } else { + CUDA_CALL(cudaFree(d_temp_storage)); + } return sorted_indices; }