Skip to content

Commit

Permalink
run pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmonDay committed May 31, 2022
1 parent 1c2d8d4 commit c51530a
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,9 @@ __global__ void copy_buffer_ac_to_final_place(

template <int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
__global__ void neighbor_sample_kernel(GpuPsCommGraph graph,
int64_t* node_index,
int* actual_size, uint64_t* res,
int sample_len, int n,
int default_value) {
int64_t* node_index, int* actual_size,
uint64_t* res, int sample_len, int n,
int default_value) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ std::vector<uint64_t> GraphGpuWrapper::graph_neighbor_sample(
auto neighbor_sample_res =
((GpuPsGraphTable *)graph_table)
->graph_neighbor_sample_v2(gpu_id, idx, cuda_key, sample_size,
key.size(), false);
key.size(), false);
int *actual_sample_size = new int[key.size()];
cudaMemcpy(actual_sample_size, neighbor_sample_res.actual_sample_size,
key.size() * sizeof(int),
Expand Down
16 changes: 12 additions & 4 deletions paddle/fluid/pybind/fleet_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -336,17 +336,25 @@ void BindGraphGpuWrapper(py::module* m) {
*m, "GraphGpuWrapper")
.def(py::init([]() { return GraphGpuWrapper::GetInstance(); }))
.def("neighbor_sample", &GraphGpuWrapper::graph_neighbor_sample_v3)
.def("graph_neighbor_sample", py::overload_cast<int, int64_t*, int, int>(&GraphGpuWrapper::graph_neighbor_sample))
.def("graph_neighbor_sample", py::overload_cast<int, int, std::vector<int64_t>&, int>(&GraphGpuWrapper::graph_neighbor_sample))
.def("graph_neighbor_sample",
py::overload_cast<int, int64_t*, int, int>(
&GraphGpuWrapper::graph_neighbor_sample))
.def("graph_neighbor_sample",
py::overload_cast<int, int, std::vector<int64_t>&, int>(
&GraphGpuWrapper::graph_neighbor_sample))
.def("set_device", &GraphGpuWrapper::set_device)
.def("set_feature_separator", &GraphGpuWrapper::set_feature_separator)
.def("init_service", &GraphGpuWrapper::init_service)
.def("set_up_types", &GraphGpuWrapper::set_up_types)
.def("query_node_list", &GraphGpuWrapper::query_node_list)
.def("add_table_feat_conf", &GraphGpuWrapper::add_table_feat_conf)
.def("load_edge_file", &GraphGpuWrapper::load_edge_file)
.def("upload_batch", py::overload_cast<int, std::vector<std::vector<int64_t>>&>(&GraphGpuWrapper::upload_batch))
.def("upload_batch", py::overload_cast<int, std::vector<std::vector<int64_t>>&, int>(&GraphGpuWrapper::upload_batch))
.def("upload_batch",
py::overload_cast<int, std::vector<std::vector<int64_t>>&>(
&GraphGpuWrapper::upload_batch))
.def("upload_batch",
py::overload_cast<int, std::vector<std::vector<int64_t>>&, int>(
&GraphGpuWrapper::upload_batch))
.def("get_all_id", &GraphGpuWrapper::get_all_id)
.def("load_next_partition", &GraphGpuWrapper::load_next_partition)
.def("make_partitions", &GraphGpuWrapper::make_partitions)
Expand Down

0 comments on commit c51530a

Please sign in to comment.