Skip to content

Commit

Permalink
update node2vec
Browse files Browse the repository at this point in the history
  • Loading branch information
gpzlx1 committed Jul 30, 2023
1 parent 1692081 commit 31be8b1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
8 changes: 8 additions & 0 deletions src/cuda/fusion/node2vec.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,21 @@ __device__ bool BinarySearch(int64_t* ptr, int64_t degree, int64_t target) {

__device__ bool CheckConnect(int64_t* graph_indice, int64_t* graph_indptr,
int64_t degree, int64_t src, int64_t dst) {
int64_t item = cub::UpperBound(graph_indice + graph_indptr[src], degree, dst);
if (item == degree) {
return false;
} else {
return true;
}
/*
if (BinarySearch(graph_indice + graph_indptr[src], degree, dst)) {
// paster()
// printf("Connect %d %d \n", src, dst);
return true;
}
// printf("not Connect %d %d \n", src, dst);
return false;
*/
}

__global__ void _Node2VecKernel(const int64_t* seed_data,
Expand Down
21 changes: 19 additions & 2 deletions src/cuda/sort_indices.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@ torch::Tensor SortIndicesCUDA(torch::Tensor indptr, torch::Tensor indices) {

torch::Tensor sorted_indices = torch::empty_like(indices);

bool use_uva = false;
if (sorted_indices.device().type() != torch::kCUDA &&
!sorted_indices.is_pinned()) {
sorted_indices = sorted_indices.pin_memory();
}

if (sorted_indices.is_pinned()) use_uva = true;

void* d_temp_storage = NULL;
size_t temp_storage_bytes = 0;

Expand All @@ -27,14 +30,28 @@ torch::Tensor SortIndicesCUDA(torch::Tensor indptr, torch::Tensor indices) {
sorted_indices.data_ptr<int64_t>(), num_items, num_segments,
indptr.data_ptr<int64_t>(), indptr.data_ptr<int64_t>() + 1);

CUDA_CALL(cudaMallocManaged(&d_temp_storage, temp_storage_bytes));
if (use_uva) {
d_temp_storage = malloc(temp_storage_bytes);
CUDA_CALL(cudaHostRegister(d_temp_storage, temp_storage_bytes,
cudaHostRegisterDefault));
// CUDA_CALL(cudaHostAlloc(&d_temp_storage, temp_storage_bytes,
// cudaHostAllocDefault));
} else {
CUDA_CALL(cudaMallocManaged(&d_temp_storage, temp_storage_bytes));
}

cub::DeviceSegmentedRadixSort::SortKeys(
d_temp_storage, temp_storage_bytes, indices.data_ptr<int64_t>(),
sorted_indices.data_ptr<int64_t>(), num_items, num_segments,
indptr.data_ptr<int64_t>(), indptr.data_ptr<int64_t>() + 1);

CUDA_CALL(cudaFree(d_temp_storage));
if (use_uva) {
CUDA_CALL(cudaHostUnregister(d_temp_storage));
free(d_temp_storage);
// CUDA_CALL(cudaFreeHost(d_temp_storage));
} else {
CUDA_CALL(cudaFree(d_temp_storage));
}
return sorted_indices;
}

Expand Down

0 comments on commit 31be8b1

Please sign in to comment.