Skip to content

Commit

Permalink
fix gpugraph cuda error;test=develop (PaddlePaddle#56133)
Browse files Browse the repository at this point in the history
  • Loading branch information
danleifeng authored Aug 15, 2023
1 parent 425b96a commit 30cefa2
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions paddle/fluid/framework/data_feed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,7 @@ void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids,
int cur_degree,
int step,
int *len_per_row) {
platform::CUDADeviceGuard guard(gpuid_);
auto gpu_graph_ptr = GraphGpuWrapper::GetInstance();
uint64_t node_id = gpu_graph_ptr->edge_to_node_map_[etype_id];
uint8_t edge_src_id = node_id >> 32;
Expand Down Expand Up @@ -2349,6 +2350,7 @@ int GraphDataGenerator::FillWalkBuf() {
break;
}
}
platform::CUDADeviceGuard guard2(gpuid_);
buf_state_.Reset(total_row_);
int *d_random_row = reinterpret_cast<int *>(d_random_row_->ptr());

Expand Down Expand Up @@ -2584,6 +2586,7 @@ int GraphDataGenerator::FillWalkBufMultiPath() {
break;
}
}
platform::CUDADeviceGuard guard2(gpuid_);
buf_state_.Reset(total_row_);
int *d_random_row = reinterpret_cast<int *>(d_random_row_->ptr());

Expand Down

0 comments on commit 30cefa2

Please sign in to comment.