Skip to content

Commit

Permalink
add whether load node and edge parallel flag (#28)
Browse files Browse the repository at this point in the history
* add whether load node and edge parallel flag

* add whether load node and edge parallel flag

* add whether load node and edge parallel flag

Co-authored-by: root <root@yq01-inf-hic-k8s-a100-ab2-0009.yq01.baidu.com>
  • Loading branch information
miaoli06 and root committed Jun 13, 2022
1 parent a7ce0cf commit d6f1800
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 99 deletions.
252 changes: 155 additions & 97 deletions paddle/fluid/distributed/ps/table/common_graph_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
#include <chrono>
#include <fstream>
#include <mutex>
#include <set>
#include <sstream>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h"

DECLARE_bool(graph_load_in_parallel);

namespace paddle {
namespace distributed {

Expand Down Expand Up @@ -1058,6 +1061,45 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
return 0;
}

// Parses one node file. Each line has the form
//   <node_type>\t<id>[\t<feature>...]
// Lines of other node types are skipped; ids that do not fall into this
// table's shard range [shard_start, shard_end) are skipped with a VLOG(4).
// Accumulates into the caller-shared counters under mutex_:
//   count       - lines of this node_type that were parsed
//   valid_count - nodes actually inserted into feature_shards
// Always returns 0.
int32_t GraphTable::parse_node_file(const std::string &path, const std::string &node_type, int idx, uint64_t &count, uint64_t &valid_count) {
  std::ifstream file(path);
  std::string line;
  uint64_t local_count = 0;
  uint64_t local_valid_count = 0;
  while (std::getline(file, line)) {
    auto values = paddle::string::split_string<std::string>(line, "\t");
    if (values.size() < 2) continue;
    if (values[0] != node_type) {
      continue;
    }
    // Count every well-formed line of this node_type (including off-shard
    // ids), so valid_count/count reported by load_nodes is meaningful and
    // consistent with parse_edge_file's counting semantics.
    local_count++;

    auto id = std::stoull(values[1]);
    size_t shard_id = id % shard_num;
    if (shard_id >= shard_end || shard_id < shard_start) {
      VLOG(4) << "will not load " << id << " from " << path
              << ", please check id distribution";
      continue;
    }

    size_t index = shard_id - shard_start;
    auto node = feature_shards[idx][index]->add_feature_node(id, false);
    if (node != NULL) {
      node->set_feature_size(feat_name[idx].size());
      // values[0] is the node type, values[1] the id; features start at 2.
      for (size_t slice = 2; slice < values.size(); slice++) {
        parse_feature(idx, values[slice], node);
      }
    }
    local_valid_count++;
  }
  {
    // RAII lock: the counters are shared across loader threads, and a bare
    // lock()/unlock() pair would leak the lock if an exception escaped.
    std::lock_guard<std::mutex> guard(mutex_);
    count += local_count;
    valid_count += local_valid_count;
  }
  VLOG(0) << "node_type[" << node_type << "] loads " << local_valid_count
          << " nodes from filepath->" << path;
  return 0;
}

// TODO: optimize to load all node_types in a single pass over the file
int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
auto paths = paddle::string::split_string<std::string>(path, ";");
Expand All @@ -1077,45 +1119,25 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
}

VLOG(0) << "Begin GraphTable::load_nodes() node_type[" << node_type << "]";
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < paths.size(); i++) {
tasks.push_back(load_node_edge_task_pool[i % load_thread_num]->enqueue(
[&, i, idx, this]() -> int {
VLOG(0) << "Begin GraphTable::load_nodes(), path[" << paths[i] << "]";
std::ifstream file(paths[i]);
std::string line;
uint64_t local_count = 0;
while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t");
if (values.size() < 2) continue;
if (values[0] != node_type) {
continue;
}

auto id = std::stoull(values[1]);
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
VLOG(4) << "will not load " << id << " from " << path
<< ", please check id distribution";
continue;
}

size_t index = shard_id - shard_start;
auto node = feature_shards[idx][index]->add_feature_node(id, false);
if (node != NULL) {
node->set_feature_size(feat_name[idx].size());
for (size_t slice = 2; slice < values.size(); slice++) {
parse_feature(idx, values[slice], node);
}
local_count++;
}
}
VLOG(0) << "node_type[" << node_type << "] loads " << local_count << " nodes from filepath->" << paths[i];
return 0;
}));
if (FLAGS_graph_load_in_parallel) {
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < paths.size(); i++) {
tasks.push_back(load_node_edge_task_pool->enqueue(
[&, i, idx, this]() -> int {
parse_node_file(paths[i], node_type, idx, count, valid_count);
return 0;
}));
}
for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
} else {
for (auto path : paths) {
VLOG(2) << "Begin GraphTable::load_nodes(), path[" << path << "]";
parse_node_file(path, node_type, idx, count, valid_count);
}
}
for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
VLOG(0) << "successfully load all node_type[" << node_type << "] data";

VLOG(0) << valid_count << "/" << count << " nodes in node_type[ " << node_type
<< "] are loaded successfully!";
return 0;
}

Expand All @@ -1128,11 +1150,77 @@ int32_t GraphTable::build_sampler(int idx, std::string sample_type) {
}
return 0;
}

// Parses one edge file. Each line has the form
//   <src_id>\t<dst_id>[\t<weight>]
// When `reverse` is true the edge direction is swapped before sharding.
// Accumulates into the caller-shared counters under mutex_:
//   count       - total lines read from this file
//   valid_count - edges actually inserted into edge_shards
// Always returns 0.
int32_t GraphTable::parse_edge_file(const std::string &path, int idx, bool reverse, uint64_t &count, uint64_t &valid_count) {
// NOTE(review): sample_type is written below but never read in this
// function — presumably leftover from the pre-refactor load_edges; confirm
// before removing.
std::string sample_type = "random";
bool is_weighted = false;
std::ifstream file(path);
std::string line;
uint64_t local_count = 0;
uint64_t local_valid_count = 0;
uint64_t part_num = 0;
if (FLAGS_graph_load_in_parallel) {
// In parallel mode the file name is expected to end in "-<part_num>";
// only edges whose source shard matches this part are loaded, so several
// threads can scan overlapping inputs without inserting duplicates.
auto path_split = paddle::string::split_string<std::string>(path, "/");
auto part_name_split = paddle::string::split_string<std::string>(path_split[path_split.size() - 1], "-");
part_num = std::stoull(part_name_split[part_name_split.size() - 1]);
}

while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t");
// local_count counts every line read, even malformed or skipped ones.
local_count++;
if (values.size() < 2) continue;
auto src_id = std::stoull(values[0]);
auto dst_id = std::stoull(values[1]);
if (reverse) {
std::swap(src_id, dst_id);
}
size_t src_shard_id = src_id % shard_num;
if (FLAGS_graph_load_in_parallel) {
// Skip edges that belong to a different file part (see part_num above).
if (src_shard_id != (part_num % shard_num)) {
continue;
}
}

// A third column, when present, is the edge weight and switches the
// whole file to weighted sampling.
float weight = 1;
if (values.size() == 3) {
weight = std::stof(values[2]);
sample_type = "weighted";
is_weighted = true;
}

if (src_shard_id >= shard_end || src_shard_id < shard_start) {
VLOG(4) << "will not load " << src_id << " from " << path
<< ", please check id distribution";
continue;
}

size_t index = src_shard_id - shard_start;
edge_shards[idx][index]->add_graph_node(src_id)->build_edges(is_weighted);
edge_shards[idx][index]->add_neighbor(src_id, dst_id, weight);
local_valid_count++;
}
// Counters are shared across loader threads; fold local tallies in under
// the table mutex.
mutex_.lock();
count += local_count;
valid_count += local_valid_count;
#ifdef PADDLE_WITH_HETERPS
// Cap in-memory edges: once the shared total exceeds the threshold in
// SSD-backed mode (search_level == 2), spill to SSD, clear the in-memory
// graph, and reset the shared counter. Done while still holding mutex_ so
// only one thread performs the dump.
const uint64_t fixed_load_edges = 1000000;
if (count > fixed_load_edges && search_level == 2) {
dump_edges_to_ssd(idx);
VLOG(0) << "dumping edges to ssd, edge count is reset to 0";
clear_graph(idx);
count = 0;
}
#endif
mutex_.unlock();
VLOG(0) << local_count << " edges are loaded from filepath->" << path;
return 0;
}

int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge,
const std::string &edge_type) {
#ifdef PADDLE_WITH_HETERPS
if (search_level == 2) total_memory_cost = 0;
const uint64_t fixed_load_edges = 1000000;
//const uint64_t fixed_load_edges = 1000000;
#endif
int idx = 0;
if (edge_type == "") {
Expand All @@ -1149,65 +1237,37 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge,

auto paths = paddle::string::split_string<std::string>(path, ";");
uint64_t count = 0;
std::string sample_type = "random";
bool is_weighted = false;
uint64_t valid_count = 0;

VLOG(0) << "Begin GraphTable::load_edges() edge_type[" << edge_type << "]";
std::vector<std::future<int>> tasks;
for (int i = 0; i < paths.size(); i++) {
tasks.push_back(load_node_edge_task_pool[i % load_thread_num]->enqueue(
[&, i, idx, this]() -> int {
uint64_t local_count = 0;
std::ifstream file(paths[i]);
std::string line;
auto path_split = paddle::string::split_string<std::string>(paths[i], "/");
auto part_name_split = paddle::string::split_string<std::string>(path_split[path_split.size() - 1], "-");
auto part_num = std::stoull(part_name_split[part_name_split.size() - 1]);

while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t");
local_count++;
if (values.size() < 2) continue;
auto src_id = std::stoull(values[0]);
auto dst_id = std::stoull(values[1]);
if (reverse_edge) {
std::swap(src_id, dst_id);
}
size_t src_shard_id = src_id % shard_num;
if (src_shard_id != (part_num % shard_num)) {
continue;
}

float weight = 1;
if (values.size() == 3) {
weight = std::stof(values[2]);
sample_type = "weighted";
is_weighted = true;
}
if (FLAGS_graph_load_in_parallel) {
std::vector<std::future<int>> tasks;
for (int i = 0; i < paths.size(); i++) {
tasks.push_back(load_node_edge_task_pool->enqueue(
[&, i, idx, this]() -> int {
parse_edge_file(paths[i], idx, reverse_edge, count, valid_count);
return 0;
}));
}
for (int j = 0; j < (int)tasks.size(); j++) tasks[j].get();
} else {
for (auto path : paths) {
parse_edge_file(path, idx, reverse_edge, count, valid_count);
}
}
VLOG(0) << valid_count << "/" << count << " edge_type[" << edge_type << "] edges are loaded successfully";

if (src_shard_id >= shard_end || src_shard_id < shard_start) {
VLOG(4) << "will not load " << src_id << " from " << path
<< ", please check id distribution";
continue;
}

size_t index = src_shard_id - shard_start;
edge_shards[idx][index]->add_graph_node(src_id)->build_edges(is_weighted);
edge_shards[idx][index]->add_neighbor(src_id, dst_id, weight);
}
#ifdef PADDLE_WITH_HETERPS
if (search_level == 2) {
dump_edges_to_ssd(idx);
VLOG(0) << "dumping edges to ssd, edge count is reset to 0";
clear_graph(idx);
}
#endif
VLOG(0) << local_count << " edges are loaded from filepath->" << paths[i];
return 0;
}));
if (search_level == 2) {
if (count > 0) {
dump_edges_to_ssd(idx);
VLOG(0) << "dumping edges to ssd, edge count is reset to 0";
clear_graph(idx);
count = 0;
}
return 0;
}
for (int j = 0; j < (int)tasks.size(); j++) tasks[j].get();
VLOG(0) << "successfully load all edge_type[" << edge_type << "] data";
#endif

#ifdef PADDLE_WITH_GPU_GRAPH
// To reduce memory overhead, CPU samplers won't be created in gpugraph.
Expand Down Expand Up @@ -1751,10 +1811,8 @@ int32_t GraphTable::Initialize(const GraphParameter &graph) {
_shards_task_pool[i].reset(new ::ThreadPool(1));
_shards_task_rng_pool.push_back(paddle::framework::GetCPURandomEngine(0));
}
load_node_edge_task_pool.resize(load_thread_num);
for (size_t i = 0; i< load_node_edge_task_pool.size(); i++) {
load_node_edge_task_pool[i].reset(new ::ThreadPool(1));
}
load_node_edge_task_pool.reset(new ::ThreadPool(load_thread_num));

auto graph_feature = graph.graph_feature();
auto node_types = graph.node_types();
auto edge_types = graph.edge_types();
Expand Down
7 changes: 5 additions & 2 deletions paddle/fluid/distributed/ps/table/common_graph_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,10 @@ class GraphTable : public Table {
int get_all_feature_ids(int type, int idx,
int slice_num, std::vector<std::vector<uint64_t>>* output);
int32_t load_nodes(const std::string &path, std::string node_type);

int32_t parse_edge_file(const std::string &path, int idx, bool reverse,
uint64_t &count, uint64_t &valid_count);
int32_t parse_node_file(const std::string &path, const std::string &node_type,
int idx, uint64_t &count, uint64_t &valid_count);
int32_t add_graph_node(int idx, std::vector<uint64_t> &id_list,
std::vector<bool> &is_weight_list);

Expand Down Expand Up @@ -617,7 +620,7 @@ class GraphTable : public Table {

std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
std::vector<std::shared_ptr<::ThreadPool>> load_node_edge_task_pool;
std::shared_ptr<::ThreadPool> load_node_edge_task_pool;
std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru;
std::unordered_set<uint64_t> extra_nodes;
std::unordered_map<uint64_t, size_t> extra_nodes_to_thread_index;
Expand Down
13 changes: 13 additions & 0 deletions paddle/fluid/platform/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,19 @@ PADDLE_DEFINE_EXPORTED_bool(
apply_pass_to_program, false,
"It controls whether to apply IR pass to program when using Fleet APIs");

/**
 * Distributed related FLAG
 * Name: FLAGS_graph_load_in_parallel
 * Since Version: 2.2.0
 * Value Range: bool, default=false
 * Example:
 * Note: Controls whether graph nodes and edges are loaded with multiple
 *       threads in parallel.
 *       If it is not set, graph data is loaded with a single thread.
 */
PADDLE_DEFINE_EXPORTED_bool(
    graph_load_in_parallel, false,
    // Fixed typos in user-visible help text: "mutli" -> "multiple",
    // "parallely" -> "in parallel".
    "It controls whether to load graph nodes and edges with multiple "
    "threads in parallel.");

/**
* KP kernel related FLAG
* Name: FLAGS_run_kp_kernel
Expand Down

0 comments on commit d6f1800

Please sign in to comment.