Merge pull request #28 from DesmonDay/gpugraph
Delete old sample interface
seemingwang authored Jun 1, 2022
2 parents 453ddf5 + c51530a commit 1d313a8
Showing 5 changed files with 21 additions and 304 deletions.
6 changes: 0 additions & 6 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
@@ -46,7 +46,6 @@ class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
}
tables_ = std::vector<Table *>(
gpu_num * (graph_table_num + feature_table_num), NULL);
-    sample_status = std::vector<int *>(gpu_num * graph_table_num, NULL);
for (int i = 0; i < gpu_num; i++) {
global_device_map[resource_->dev_id(i)] = i;
for (int j = 0; j < graph_table_num; j++) {
@@ -123,13 +122,9 @@ class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
bool cpu_switch);
NeighborSampleResult graph_neighbor_sample(int gpu_id, uint64_t *key,
int sample_size, int len);
-  NeighborSampleResult graph_neighbor_sample(int gpu_id, int idx, uint64_t *key,
-                                             int sample_size, int len);
NeighborSampleResult graph_neighbor_sample_v2(int gpu_id, int idx,
uint64_t *key, int sample_size,
int len, bool cpu_query_switch);
-  void init_sample_status();
-  void free_sample_status();
NodeQueryResult query_node_list(int gpu_id, int idx, int start,
int query_size);
void display_sample_res(void *key, void *val, int len, int sample_len);
@@ -144,7 +139,6 @@ class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
std::vector<GpuPsCommGraph> gpu_graph_list_;
std::vector<GpuPsCommGraphFea> gpu_graph_fea_list_;
int global_device_map[32];
-  std::vector<int *> sample_status;
const int parallel_sample_size = 1;
const int dim_y = 256;
std::shared_ptr<paddle::distributed::GraphTable> cpu_graph_table_;
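
After this change the header keeps one plain entry point plus the v2 variant. A minimal caller sketch (editor's illustration, not part of the commit: `table` and `d_keys` are assumed to exist, and the numeric arguments are arbitrary):

// Hypothetical usage of the surviving interface declared above.
GpuPsGraphTable& table = /* an initialized table */;
uint64_t* d_keys = /* device array of node ids on card 0 */;
NeighborSampleResult res =
    table.graph_neighbor_sample(/*gpu_id=*/0, d_keys, /*sample_size=*/10,
                                /*len=*/1024);
NeighborSampleResult res2 = table.graph_neighbor_sample_v2(
    /*gpu_id=*/0, /*idx=*/0, d_keys, /*sample_size=*/10, /*len=*/1024,
    /*cpu_query_switch=*/false);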
288 changes: 7 additions & 281 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu
@@ -69,11 +69,10 @@ __global__ void copy_buffer_ac_to_final_place(
}

template <int WARP_SIZE, int BLOCK_WARPS, int TILE_SIZE>
-__global__ void neighbor_sample_example_v2(GpuPsCommGraph graph,
-                                           int64_t* node_index,
-                                           int* actual_size, uint64_t* res,
-                                           int sample_len, int n,
-                                           int default_value) {
+__global__ void neighbor_sample_kernel(GpuPsCommGraph graph,
+                                       int64_t* node_index, int* actual_size,
+                                       uint64_t* res, int sample_len, int n,
+                                       int default_value) {
assert(blockDim.x == WARP_SIZE);
assert(blockDim.y == BLOCK_WARPS);

@@ -120,81 +119,6 @@ __global__ void neighbor_sample_example_v2(GpuPsCommGraph graph,
}
}

-__global__ void neighbor_sample_example(GpuPsCommGraph graph,
-                                        int64_t* node_index, int* actual_size,
-                                        uint64_t* res, int sample_len,
-                                        int* sample_status, int n, int from) {
-  int id = blockIdx.x * blockDim.y + threadIdx.y;
-  if (id < n) {
-    if (node_index[id] == -1) {
-      actual_size[id] = 0;
-      return;
-    }
-    curandState rng;
-    curand_init(blockIdx.x, threadIdx.x, threadIdx.y, &rng);
-    int64_t index = threadIdx.x;
-    int64_t offset = id * sample_len;
-    uint64_t* data = graph.neighbor_list;
-    int64_t data_offset = graph.node_list[node_index[id]].neighbor_offset;
-    uint64_t neighbor_len = graph.node_list[node_index[id]].neighbor_size;
-    int ac_len;
-    if (sample_len > neighbor_len)
-      ac_len = neighbor_len;
-    else {
-      ac_len = sample_len;
-    }
-    if (4 * ac_len >= 3 * neighbor_len) {
-      if (index == 0) {
-        res[offset] = curand(&rng) % (neighbor_len - ac_len + 1);
-      }
-      __syncwarp();
-      int start = res[offset];
-      while (index < ac_len) {
-        res[offset + index] = data[data_offset + start + index];
-        index += blockDim.x;
-      }
-      actual_size[id] = ac_len;
-    } else {
-      while (index < ac_len) {
-        int num = curand(&rng) % neighbor_len;
-        int* addr = sample_status + data_offset + num;
-        int expected = *addr;
-        if (!(expected & (1 << from))) {
-          int old = atomicCAS(addr, expected, expected | (1 << from));
-          if (old == expected) {
-            res[offset + index] = num;
-            index += blockDim.x;
-          }
-        }
-      }
-      __syncwarp();
-      index = threadIdx.x;
-      while (index < ac_len) {
-        int* addr = sample_status + data_offset + res[offset + index];
-        int expected, old = *addr;
-        do {
-          expected = old;
-          old = atomicCAS(addr, expected, expected & (~(1 << from)));
-        } while (old != expected);
-        res[offset + index] = data[data_offset + res[offset + index]];
-        index += blockDim.x;
-      }
-      actual_size[id] = ac_len;
-    }
-  }
-  // const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
-  // if (i < n) {
-  //   auto node_index = index[i];
-  //   actual_size[i] = graph.node_list[node_index].neighbor_size < sample_size
-  //                        ? graph.node_list[node_index].neighbor_size
-  //                        : sample_size;
-  //   int offset = graph.node_list[node_index].neighbor_offset;
-  //   for (int j = 0; j < actual_size[i]; j++) {
-  //     sample_result[sample_size * i + j] = graph.neighbor_list[offset + j];
-  //   }
-  // }
-}
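
The deleted kernel above used two strategies: when the requested sample covers at least three quarters of a node's neighbor list (4 * ac_len >= 3 * neighbor_len), it copies a random contiguous window of neighbors; otherwise it samples without replacement by claiming per-edge bits in the sample_status array. Below is an editor's sketch of that claim/release pattern in isolation, reduced to one neighbor list and one warp; all names are hypothetical and it assumes sample_len <= neighbor_len and a zero-initialized status array:

#include <curand_kernel.h>

// Claim/release sampling without replacement via a shared bitmask.
// Launch with a single warp, e.g. <<<1, 32>>>.
__global__ void bitmask_sample_sketch(const uint64_t* neighbors,
                                      int neighbor_len, int sample_len,
                                      int* sample_status, int from,
                                      uint64_t* out) {
  curandState rng;
  curand_init(clock64(), threadIdx.x, 0, &rng);
  int index = threadIdx.x;
  while (index < sample_len) {
    int num = curand(&rng) % neighbor_len;
    int expected = sample_status[num];
    // Try to set bit `from`; exactly one lane can win a given slot.
    if (!(expected & (1 << from)) &&
        atomicCAS(sample_status + num, expected,
                  expected | (1 << from)) == expected) {
      out[index] = num;  // remember the claimed slot index
      index += blockDim.x;
    }
  }
  __syncwarp();
  // Release the bits and materialize the neighbor values.
  for (index = threadIdx.x; index < sample_len; index += blockDim.x) {
    atomicAnd(sample_status + out[index], ~(1 << from));
    out[index] = neighbors[out[index]];
  }
}

The bit index `from` lets several source gpus sample the same graph shard concurrently, each tracking its own in-flight claims; the deleted code passed gpu_id for it.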

int GpuPsGraphTable::init_cpu_table(
const paddle::distributed::GraphParameter& graph) {
cpu_graph_table_.reset(new paddle::distributed::GraphTable);
@@ -257,6 +181,7 @@ void GpuPsGraphTable::display_sample_res(void* key, void* val, int len,
printf("\n");
}
}

void GpuPsGraphTable::move_neighbor_sample_result_to_source_gpu(
int start_index, int gpu_num, int sample_size, int* h_left, int* h_right,
uint64_t* src_sample_res, int* actual_sample_size) {
@@ -397,7 +322,6 @@ void GpuPsGraphTable::build_graph_fea_on_single_gpu(GpuPsCommGraphFea& g,
int offset = gpu_id * feature_table_num_ + ntype_id;
gpu_graph_fea_list_[offset] = GpuPsCommGraphFea();

-  sample_status[offset] = NULL;
int table_offset =
get_table_offset(gpu_id, GraphTableType::FEATURE_TABLE, ntype_id);

@@ -473,7 +397,6 @@ void GpuPsGraphTable::build_graph_on_single_gpu(GpuPsCommGraph& g, int i,
platform::CUDADeviceGuard guard(resource_->dev_id(i));
int offset = i * graph_table_num_ + idx;
gpu_graph_list_[offset] = GpuPsCommGraph();
-  sample_status[offset] = NULL;
int table_offset = get_table_offset(i, GraphTableType::EDGE_TABLE, idx);
size_t capacity = std::max((uint64_t)1, (uint64_t)g.node_size) / load_factor_;
tables_[table_offset] = new Table(capacity);
@@ -515,35 +438,6 @@ void GpuPsGraphTable::build_graph_on_single_gpu(GpuPsCommGraph& g, int i,
}
}

-void GpuPsGraphTable::init_sample_status() {
-  for (int i = 0; i < gpu_num; i++) {
-    for (int j = 0; j < graph_table_num_; j++) {
-      int offset = i * graph_table_num_ + j;
-      if (gpu_graph_list_[offset].neighbor_size) {
-        platform::CUDADeviceGuard guard(resource_->dev_id(i));
-        int* addr;
-        cudaMalloc((void**)&addr,
-                   gpu_graph_list_[offset].neighbor_size * sizeof(int));
-        cudaMemset(addr, 0,
-                   gpu_graph_list_[offset].neighbor_size * sizeof(int));
-        sample_status[offset] = addr;
-      }
-    }
-  }
-}
-
-void GpuPsGraphTable::free_sample_status() {
-  for (int i = 0; i < gpu_num; i++) {
-    for (int j = 0; j < graph_table_num_; j++) {
-      int offset = i * graph_table_num_ + j;
-      if (sample_status[offset] != NULL) {
-        platform::CUDADeviceGuard guard(resource_->dev_id(i));
-        cudaFree(sample_status[offset]);
-      }
-    }
-  }
-}

void GpuPsGraphTable::build_graph_fea_from_cpu(
std::vector<GpuPsCommGraphFea>& cpu_graph_fea_list, int ntype_id) {
PADDLE_ENFORCE_EQ(
@@ -557,7 +451,6 @@ void GpuPsGraphTable::build_graph_fea_from_cpu(
int offset = i * feature_table_num_ + ntype_id;
platform::CUDADeviceGuard guard(resource_->dev_id(i));
gpu_graph_fea_list_[offset] = GpuPsCommGraphFea();
-  sample_status[offset] = NULL;
tables_[table_offset] = new Table(
std::max((uint64_t)1, (uint64_t)cpu_graph_fea_list[i].node_size) /
load_factor_);
@@ -627,7 +520,6 @@ void GpuPsGraphTable::build_graph_from_cpu(
int offset = i * graph_table_num_ + idx;
platform::CUDADeviceGuard guard(resource_->dev_id(i));
gpu_graph_list_[offset] = GpuPsCommGraph();
-  sample_status[offset] = NULL;
tables_[table_offset] =
new Table(std::max((uint64_t)1, (uint64_t)cpu_graph_list[i].node_size) /
load_factor_);
@@ -679,173 +571,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id,
uint64_t* key,
int sample_size,
int len) {
-  return graph_neighbor_sample(gpu_id, 0, key, sample_size, len);
-}
-
-NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample(int gpu_id, int idx,
-                                                            uint64_t* key,
-                                                            int sample_size,
-                                                            int len) {
-  /*
-   comment 2
-   this function shares some kernels with heter_comm_inl.h
-   arguments definitions:
-   gpu_id: the id of the gpu
-   len: how many keys are used (the length of the key array)
-   sample_size: how many neighbors should be sampled for each node in key
-
-   the code below shuffles the key array so that keys belonging to the same
-   gpu card stay together; the shuffled result is saved in d_shard_keys.
-   if the ith element of d_shard_keys_ptr comes from the jth element of the
-   original key array, then idx[i] = j; idx can be used to recover the
-   original array.
-   if the keys in range [a,b] belong to the ith gpu, then h_left[i] = a and
-   h_right[i] = b; if no keys are allocated to the ith gpu, then
-   h_left[i] == h_right[i] == -1.
-   for example, suppose key = [0,1,2,3,4,5,6,7,8] and gpu_num = 2.
-   when we run this neighbor_sample function, the key array is shuffled to
-   [0,2,4,6,8,1,3,5,7]; the first part (0,2,4,6,8) % 2 == 0 and is handled
-   by gpu 0, while the rest (1,3,5,7) % 2 == 1 is handled by gpu 1;
-   h_left = [0,5], h_right = [4,8].
-  */
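
  // Editor's illustration (hypothetical, not part of this commit): the
  // bookkeeping described above, for key = [0,1,2,3,4,5,6,7,8], gpu_num = 2:
  //   d_shard_keys : [0, 2, 4, 6, 8, 1, 3, 5, 7]
  //   idx          : [0, 2, 4, 6, 8, 1, 3, 5, 7]  // origin of each slot
  //   h_left       : [0, 5]
  //   h_right      : [4, 8]
  // fill_dvalues later uses idx to scatter slot i's results back to
  // original position idx[i].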

-  NeighborSampleResult result;
-  result.initialize(sample_size, len, resource_->dev_id(gpu_id));
-  if (len == 0) {
-    return result;
-  }
-  platform::CUDAPlace place = platform::CUDAPlace(resource_->dev_id(gpu_id));
-  platform::CUDADeviceGuard guard(resource_->dev_id(gpu_id));
-  int* actual_sample_size = result.actual_sample_size;
-  uint64_t* val = result.val;
-  int total_gpu = resource_->total_device();
-  auto stream = resource_->local_stream(gpu_id, 0);
-
-  int grid_size = (len - 1) / block_size_ + 1;
-
-  int h_left[total_gpu];   // NOLINT
-  int h_right[total_gpu];  // NOLINT
-
-  auto d_left = memory::Alloc(place, total_gpu * sizeof(int));
-  auto d_right = memory::Alloc(place, total_gpu * sizeof(int));
-  int* d_left_ptr = reinterpret_cast<int*>(d_left->ptr());
-  int* d_right_ptr = reinterpret_cast<int*>(d_right->ptr());
-
-  cudaMemsetAsync(d_left_ptr, -1, total_gpu * sizeof(int), stream);
-  cudaMemsetAsync(d_right_ptr, -1, total_gpu * sizeof(int), stream);
-  //
-  auto d_idx = memory::Alloc(place, len * sizeof(int));
-  int* d_idx_ptr = reinterpret_cast<int*>(d_idx->ptr());
-
-  auto d_shard_keys = memory::Alloc(place, len * sizeof(uint64_t));
-  uint64_t* d_shard_keys_ptr = reinterpret_cast<uint64_t*>(d_shard_keys->ptr());
-  auto d_shard_vals =
-      memory::Alloc(place, sample_size * len * sizeof(uint64_t));
-  uint64_t* d_shard_vals_ptr = reinterpret_cast<uint64_t*>(d_shard_vals->ptr());
-  auto d_shard_actual_sample_size = memory::Alloc(place, len * sizeof(int));
-  int* d_shard_actual_sample_size_ptr =
-      reinterpret_cast<int*>(d_shard_actual_sample_size->ptr());
-
-  split_input_to_shard((uint64_t*)(key), d_idx_ptr, len, d_left_ptr,
-                       d_right_ptr, gpu_id);
-
-  heter_comm_kernel_->fill_shard_key(d_shard_keys_ptr, key, d_idx_ptr, len,
-                                     stream);
-  cudaStreamSynchronize(stream);
-
-  cudaMemcpy(h_left, d_left_ptr, total_gpu * sizeof(int),
-             cudaMemcpyDeviceToHost);
-  cudaMemcpy(h_right, d_right_ptr, total_gpu * sizeof(int),
-             cudaMemcpyDeviceToHost);
-  // auto start1 = std::chrono::steady_clock::now();
-  for (int i = 0; i < total_gpu; ++i) {
-    int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
-    if (shard_len == 0) {
-      continue;
-    }
-    /*
-     comment 3
-     shard_len denotes the number of keys on the i-th gpu here.
-     when we sample on the i-th gpu, we allocate shard_len * (1 + sample_size)
-     int64_t units of memory; call it alloc_mem_i. the range [0, shard_len)
-     is reserved for the respective nodes' indexes and actual sample sizes.
-     with the nodes' indexes we can get the nodes to sample.
-     since the size of int64_t is 8 bytes while the size of int is 4, the
-     range [0, shard_len) contains shard_len * 2 int units;
-     the values of the first half of this range are filled in by the k-v map
-     on the i-th gpu.
-     the second half of this range is reserved for the actual sample size of
-     each node. for node x, its sampling result is saved in the range
-     [shard_len + sample_size * x,
-      shard_len + sample_size * x + actual_sample_size_of_x)
-     of alloc_mem_i, where actual_sample_size_of_x equals
-     ((int *)alloc_mem_i)[shard_len + x].
-    */
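
    // Editor's diagram (illustration, not part of this commit): the layout
    // described above, as the pointer math further down carves it out of
    // the per-gpu val_storage buffer:
    //   int64_t  id_array[shard_len]        <- node indexes from tables_->get()
    //   int      actual_size_array[shard_len (+1 int of padding if odd)]
    //   uint64_t sample_array[shard_len * sample_size]
    // totalling shard_len * (1 + sample_size) * sizeof(int64_t)
    //   + sizeof(int) * (shard_len + shard_len % 2) bytes,
    // exactly what create_storage() requests just below.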

-    create_storage(gpu_id, i, shard_len * sizeof(int64_t),
-                   shard_len * (1 + sample_size) * sizeof(int64_t) +
-                       sizeof(int) * (shard_len + shard_len % 2));
-    // auto& node = path_[gpu_id][i].nodes_[0];
-  }
-  walk_to_dest(gpu_id, total_gpu, h_left, h_right,
-               (uint64_t*)(d_shard_keys_ptr), NULL);
-
-  for (int i = 0; i < total_gpu; ++i) {
-    if (h_left[i] == -1) {
-      continue;
-    }
-    int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
-    auto& node = path_[gpu_id][i].nodes_.back();
-    cudaMemsetAsync(node.val_storage, -1, shard_len * sizeof(int64_t),
-                    node.in_stream);
-    cudaStreamSynchronize(node.in_stream);
-    platform::CUDADeviceGuard guard(resource_->dev_id(i));
-    int offset = i * graph_table_num_ + idx;
-    int table_offset = get_table_offset(i, GraphTableType::EDGE_TABLE, idx);
-    tables_[table_offset]->get(reinterpret_cast<uint64_t*>(node.key_storage),
-                               reinterpret_cast<int64_t*>(node.val_storage),
-                               h_right[i] - h_left[i] + 1,
-                               resource_->remote_stream(i, gpu_id));
-    // node.in_stream);
-    auto graph = gpu_graph_list_[offset];
-    int64_t* id_array = reinterpret_cast<int64_t*>(node.val_storage);
-    int* actual_size_array = (int*)(id_array + shard_len);
-    uint64_t* sample_array =
-        (uint64_t*)(actual_size_array + shard_len + shard_len % 2);
-    int sample_grid_size = (shard_len - 1) / dim_y + 1;
-    dim3 block(parallel_sample_size, dim_y);
-    dim3 grid(sample_grid_size);
-    neighbor_sample_example<<<grid, block, 0,
-                              resource_->remote_stream(i, gpu_id)>>>(
-        graph, id_array, actual_size_array, sample_array, sample_size,
-        sample_status[offset], shard_len, gpu_id);
-  }
-  for (int i = 0; i < total_gpu; ++i) {
-    if (h_left[i] == -1) {
-      continue;
-    }
-    cudaStreamSynchronize(resource_->remote_stream(i, gpu_id));
-  }
-  move_neighbor_sample_result_to_source_gpu(gpu_id, total_gpu, sample_size,
-                                            h_left, h_right, d_shard_vals_ptr,
-                                            d_shard_actual_sample_size_ptr);
-  fill_dvalues<<<grid_size, block_size_, 0, stream>>>(
-      d_shard_vals_ptr, val, d_shard_actual_sample_size_ptr,
-      actual_sample_size, d_idx_ptr, sample_size, len);
-  for (int i = 0; i < total_gpu; ++i) {
-    int shard_len = h_left[i] == -1 ? 0 : h_right[i] - h_left[i] + 1;
-    if (shard_len == 0) {
-      continue;
-    }
-    destroy_storage(gpu_id, i);
-  }
-  cudaStreamSynchronize(stream);
-  return result;
+  return graph_neighbor_sample_v2(gpu_id, 0, key, sample_size, len, false);
}

NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
@@ -947,7 +673,7 @@ NeighborSampleResult GpuPsGraphTable::graph_neighbor_sample_v2(
constexpr int TILE_SIZE = BLOCK_WARPS * 16;
const dim3 block(WARP_SIZE, BLOCK_WARPS);
const dim3 grid((shard_len + TILE_SIZE - 1) / TILE_SIZE);
-  neighbor_sample_example_v2<
+  neighbor_sample_kernel<
WARP_SIZE, BLOCK_WARPS,
TILE_SIZE><<<grid, block, 0, resource_->remote_stream(i, gpu_id)>>>(
graph, id_array, actual_size_array, sample_array, sample_size,
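
The launch geometry for the renamed kernel follows from the context lines above. A worked example (editor's arithmetic; the WARP_SIZE and BLOCK_WARPS values are assumptions, since their definitions sit outside this hunk):

// Hypothetical values: WARP_SIZE = 32, BLOCK_WARPS = 4.
// TILE_SIZE = BLOCK_WARPS * 16 = 64 keys per block.
// For a shard of 1000 keys on one remote gpu:
//   grid  = (1000 + 64 - 1) / 64 = 16 blocks
//   block = dim3(32, 4)  // 4 warps; each warp covers 16 key rows,
// matching the asserts blockDim.x == WARP_SIZE and blockDim.y == BLOCK_WARPS
// at the top of neighbor_sample_kernel.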
