Add a flag to control whether graph nodes and edges are loaded in parallel #28

Merged
3 commits merged on Jun 13, 2022
252 changes: 155 additions & 97 deletions paddle/fluid/distributed/ps/table/common_graph_table.cc
@@ -18,12 +18,15 @@
#include <chrono>
#include <set>
#include <sstream>
#include "gflags/gflags.h"
#include "paddle/fluid/distributed/common/utils.h"
#include "paddle/fluid/distributed/ps/table/graph/graph_node.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/string_helper.h"

DECLARE_bool(graph_load_in_parallel);

namespace paddle {
namespace distributed {

@@ -1058,6 +1061,45 @@ int32_t GraphTable::get_nodes_ids_by_ranges(
return 0;
}

int32_t GraphTable::parse_node_file(const std::string &path, const std::string &node_type, int idx, uint64_t &count, uint64_t &valid_count) {
std::ifstream file(path);
std::string line;
uint64_t local_count = 0;
uint64_t local_valid_count = 0;
while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t");
if (values.size() < 2) continue;
if (values[0] != node_type) {
continue;
}

auto id = std::stoull(values[1]);
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
VLOG(4) << "will not load " << id << " from " << path
<< ", please check id distribution";
continue;
}
local_count++;

size_t index = shard_id - shard_start;
auto node = feature_shards[idx][index]->add_feature_node(id, false);
if (node != NULL) {
node->set_feature_size(feat_name[idx].size());
for (size_t slice = 2; slice < values.size(); slice++) {
parse_feature(idx, values[slice], node);
}
}
local_valid_count++;
}
mutex_.lock();
count += local_count;
valid_count += local_valid_count;
mutex_.unlock();
VLOG(0) << "node_type[" << node_type << "] loads " << local_count << " nodes from filepath->" << path;
return 0;
}
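
For reference, here is a self-contained sketch (not part of the PR) of the tab-separated node-file layout that parse_node_file() above consumes, together with the id % shard_num routing it applies. The sample line, shard boundaries, and feature tokens are made up; the actual feature encoding is handled by parse_feature() and is not reproduced.

#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  // Expected layout: <node_type>\t<node_id>[\t<feature>...]
  const std::string sample_line = "user\t10036\tfeat_a\tfeat_b";
  const uint64_t shard_num = 16, shard_start = 4, shard_end = 12;

  std::vector<std::string> values;
  std::stringstream ss(sample_line);
  for (std::string col; std::getline(ss, col, '\t');) values.push_back(col);

  const uint64_t id = std::stoull(values[1]);
  const uint64_t shard_id = id % shard_num;  // 10036 % 16 == 4
  const bool owned = shard_id >= shard_start && shard_id < shard_end;
  std::cout << "node_type=" << values[0] << " id=" << id << " shard_id="
            << shard_id << (owned ? " -> kept" : " -> skipped") << std::endl;
  return 0;
}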

// TODO: optimize to load all node_types in a single pass over the files
int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
auto paths = paddle::string::split_string<std::string>(path, ";");
@@ -1077,45 +1119,25 @@ int32_t GraphTable::load_nodes(const std::string &path, std::string node_type) {
}

VLOG(0) << "Begin GraphTable::load_nodes() node_type[" << node_type << "]";
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < paths.size(); i++) {
tasks.push_back(load_node_edge_task_pool[i % load_thread_num]->enqueue(
[&, i, idx, this]() -> int {
VLOG(0) << "Begin GraphTable::load_nodes(), path[" << paths[i] << "]";
std::ifstream file(paths[i]);
std::string line;
uint64_t local_count = 0;
while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t");
if (values.size() < 2) continue;
if (values[0] != node_type) {
continue;
}

auto id = std::stoull(values[1]);
size_t shard_id = id % shard_num;
if (shard_id >= shard_end || shard_id < shard_start) {
VLOG(4) << "will not load " << id << " from " << path
<< ", please check id distribution";
continue;
}

size_t index = shard_id - shard_start;
auto node = feature_shards[idx][index]->add_feature_node(id, false);
if (node != NULL) {
node->set_feature_size(feat_name[idx].size());
for (size_t slice = 2; slice < values.size(); slice++) {
parse_feature(idx, values[slice], node);
}
local_count++;
}
}
VLOG(0) << "node_type[" << node_type << "] loads " << local_count << " nodes from filepath->" << paths[i];
return 0;
}));
if (FLAGS_graph_load_in_parallel) {
std::vector<std::future<int>> tasks;
for (size_t i = 0; i < paths.size(); i++) {
tasks.push_back(load_node_edge_task_pool->enqueue(
[&, i, idx, this]() -> int {
parse_node_file(paths[i], node_type, idx, count, valid_count);
return 0;
}));
}
for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
} else {
for (auto path : paths) {
VLOG(2) << "Begin GraphTable::load_nodes(), path[" << path << "]";
parse_node_file(path, node_type, idx, count, valid_count);
}
}
for (int i = 0; i < (int)tasks.size(); i++) tasks[i].get();
VLOG(0) << "successfully load all node_type[" << node_type << "] data";

VLOG(0) << valid_count << "/" << count << " nodes in node_type[ " << node_type
<< "] are loaded successfully!";
return 0;
}
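
A side note on the parallel branch above: the sketch below (not the PR's code) shows the same fan-out/join pattern in miniature, with one task per input file, a mutex guarding the shared counter, and a blocking get() on every future. std::async stands in for Paddle's ::ThreadPool, whose enqueue() likewise returns a std::future; the file names and per-file counts are invented.

#include <cstdint>
#include <future>
#include <iostream>
#include <mutex>
#include <string>
#include <vector>

int main() {
  const std::vector<std::string> paths = {"part-0", "part-1", "part-2"};
  uint64_t count = 0;
  std::mutex mu;

  std::vector<std::future<int>> tasks;
  for (size_t i = 0; i < paths.size(); ++i) {
    tasks.push_back(std::async(std::launch::async, [&count, &mu]() -> int {
      const uint64_t local_count = 1;        // pretend each file yields one record
      std::lock_guard<std::mutex> guard(mu);  // same role as mutex_ above
      count += local_count;
      return 0;
    }));
  }
  for (auto &t : tasks) t.get();  // join, as load_nodes() does with tasks[i].get()
  std::cout << "loaded " << count << " records from " << paths.size()
            << " files" << std::endl;
  return 0;
}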

@@ -1128,11 +1150,77 @@ int32_t GraphTable::build_sampler(int idx, std::string sample_type) {
}
return 0;
}

int32_t GraphTable::parse_edge_file(const std::string &path, int idx, bool reverse, uint64_t &count, uint64_t &valid_count) {
std::string sample_type = "random";
bool is_weighted = false;
std::ifstream file(path);
std::string line;
uint64_t local_count = 0;
uint64_t local_valid_count = 0;
uint64_t part_num = 0;
if (FLAGS_graph_load_in_parallel) {
auto path_split = paddle::string::split_string<std::string>(path, "/");
auto part_name_split = paddle::string::split_string<std::string>(path_split[path_split.size() - 1], "-");
part_num = std::stoull(part_name_split[part_name_split.size() - 1]);
}

while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t");
local_count++;
if (values.size() < 2) continue;
auto src_id = std::stoull(values[0]);
auto dst_id = std::stoull(values[1]);
if (reverse) {
std::swap(src_id, dst_id);
}
size_t src_shard_id = src_id % shard_num;
if (FLAGS_graph_load_in_parallel) {
if (src_shard_id != (part_num % shard_num)) {
continue;
}
}

float weight = 1;
if (values.size() == 3) {
weight = std::stof(values[2]);
sample_type = "weighted";
is_weighted = true;
}

if (src_shard_id >= shard_end || src_shard_id < shard_start) {
VLOG(4) << "will not load " << src_id << " from " << path
<< ", please check id distribution";
continue;
}

size_t index = src_shard_id - shard_start;
edge_shards[idx][index]->add_graph_node(src_id)->build_edges(is_weighted);
edge_shards[idx][index]->add_neighbor(src_id, dst_id, weight);
local_valid_count++;
}
mutex_.lock();
count += local_count;
valid_count += local_valid_count;
#ifdef PADDLE_WITH_HETERPS
const uint64_t fixed_load_edges = 1000000;
if (count > fixed_load_edges && search_level == 2) {
dump_edges_to_ssd(idx);
VLOG(0) << "dumping edges to ssd, edge count is reset to 0";
clear_graph(idx);
count = 0;
}
#endif
mutex_.unlock();
VLOG(0) << local_count << " edges are loaded from filepath->" << path;
return 0;
}
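
To make the file-name filtering above concrete, here is a self-contained sketch (not part of the PR) of the <src_id>\t<dst_id>[\t<weight>] edge layout and of how, when FLAGS_graph_load_in_parallel is on, the trailing number of a file name such as part-00003 is matched against src_id % shard_num. The path, shard count, and ids are made up.

#include <cstdint>
#include <iostream>
#include <string>

int main() {
  const std::string path = "/data/graph/edges/part-00003";
  const uint64_t shard_num = 16;

  // Recover the part number from the suffix after the last '-' (here: 3).
  const uint64_t part_num = std::stoull(path.substr(path.rfind('-') + 1));

  const uint64_t src_id = 10035, dst_id = 77;  // one sample edge, weight omitted
  const bool accepted = (src_id % shard_num) == (part_num % shard_num);
  std::cout << "edge " << src_id << " -> " << dst_id
            << (accepted ? " accepted" : " filtered out")
            << " for part " << part_num << std::endl;
  return 0;
}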

int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge,
const std::string &edge_type) {
#ifdef PADDLE_WITH_HETERPS
if (search_level == 2) total_memory_cost = 0;
const uint64_t fixed_load_edges = 1000000;
//const uint64_t fixed_load_edges = 1000000;
#endif
int idx = 0;
if (edge_type == "") {
@@ -1149,65 +1237,37 @@ int32_t GraphTable::load_edges(const std::string &path, bool reverse_edge,

auto paths = paddle::string::split_string<std::string>(path, ";");
uint64_t count = 0;
std::string sample_type = "random";
bool is_weighted = false;
uint64_t valid_count = 0;

VLOG(0) << "Begin GraphTable::load_edges() edge_type[" << edge_type << "]";
std::vector<std::future<int>> tasks;
for (int i = 0; i < paths.size(); i++) {
tasks.push_back(load_node_edge_task_pool[i % load_thread_num]->enqueue(
[&, i, idx, this]() -> int {
uint64_t local_count = 0;
std::ifstream file(paths[i]);
std::string line;
auto path_split = paddle::string::split_string<std::string>(paths[i], "/");
auto part_name_split = paddle::string::split_string<std::string>(path_split[path_split.size() - 1], "-");
auto part_num = std::stoull(part_name_split[part_name_split.size() - 1]);

while (std::getline(file, line)) {
auto values = paddle::string::split_string<std::string>(line, "\t");
local_count++;
if (values.size() < 2) continue;
auto src_id = std::stoull(values[0]);
auto dst_id = std::stoull(values[1]);
if (reverse_edge) {
std::swap(src_id, dst_id);
}
size_t src_shard_id = src_id % shard_num;
if (src_shard_id != (part_num % shard_num)) {
continue;
}

float weight = 1;
if (values.size() == 3) {
weight = std::stof(values[2]);
sample_type = "weighted";
is_weighted = true;
}
if (FLAGS_graph_load_in_parallel) {
std::vector<std::future<int>> tasks;
for (int i = 0; i < paths.size(); i++) {
tasks.push_back(load_node_edge_task_pool->enqueue(
[&, i, idx, this]() -> int {
parse_edge_file(paths[i], idx, reverse_edge, count, valid_count);
return 0;
}));
}
for (int j = 0; j < (int)tasks.size(); j++) tasks[j].get();
} else {
for (auto path : paths) {
parse_edge_file(path, idx, reverse_edge, count, valid_count);
}
}
VLOG(0) << valid_count << "/" << count << " edge_type[" << edge_type << "] edges are loaded successfully";

if (src_shard_id >= shard_end || src_shard_id < shard_start) {
VLOG(4) << "will not load " << src_id << " from " << path
<< ", please check id distribution";
continue;
}

size_t index = src_shard_id - shard_start;
edge_shards[idx][index]->add_graph_node(src_id)->build_edges(is_weighted);
edge_shards[idx][index]->add_neighbor(src_id, dst_id, weight);
}
#ifdef PADDLE_WITH_HETERPS
if (search_level == 2) {
dump_edges_to_ssd(idx);
VLOG(0) << "dumping edges to ssd, edge count is reset to 0";
clear_graph(idx);
}
#endif
VLOG(0) << local_count << " edges are loaded from filepath->" << paths[i];
return 0;
}));
if (search_level == 2) {
if (count > 0) {
dump_edges_to_ssd(idx);
VLOG(0) << "dumping edges to ssd, edge count is reset to 0";
clear_graph(idx);
count = 0;
}
return 0;
}
for (int j = 0; j < (int)tasks.size(); j++) tasks[j].get();
VLOG(0) << "successfully load all edge_type[" << edge_type << "] data";
#endif

#ifdef PADDLE_WITH_GPU_GRAPH
// To reduce memory overhead, CPU samplers won't be created in gpugraph.
@@ -1751,10 +1811,8 @@ int32_t GraphTable::Initialize(const GraphParameter &graph) {
_shards_task_pool[i].reset(new ::ThreadPool(1));
_shards_task_rng_pool.push_back(paddle::framework::GetCPURandomEngine(0));
}
load_node_edge_task_pool.resize(load_thread_num);
for (size_t i = 0; i< load_node_edge_task_pool.size(); i++) {
load_node_edge_task_pool[i].reset(new ::ThreadPool(1));
}
load_node_edge_task_pool.reset(new ::ThreadPool(load_thread_num));

auto graph_feature = graph.graph_feature();
auto node_types = graph.node_types();
auto edge_types = graph.edge_types();
7 changes: 5 additions & 2 deletions paddle/fluid/distributed/ps/table/common_graph_table.h
@@ -492,7 +492,10 @@ class GraphTable : public Table {
int get_all_feature_ids(int type, int idx,
int slice_num, std::vector<std::vector<uint64_t>>* output);
int32_t load_nodes(const std::string &path, std::string node_type);

int32_t parse_edge_file(const std::string &path, int idx, bool reverse,
uint64_t &count, uint64_t &valid_count);
int32_t parse_node_file(const std::string &path, const std::string &node_type,
int idx, uint64_t &count, uint64_t &valid_count);
int32_t add_graph_node(int idx, std::vector<uint64_t> &id_list,
std::vector<bool> &is_weight_list);

Expand Down Expand Up @@ -617,7 +620,7 @@ class GraphTable : public Table {

std::vector<std::shared_ptr<::ThreadPool>> _shards_task_pool;
std::vector<std::shared_ptr<std::mt19937_64>> _shards_task_rng_pool;
std::vector<std::shared_ptr<::ThreadPool>> load_node_edge_task_pool;
std::shared_ptr<::ThreadPool> load_node_edge_task_pool;
std::shared_ptr<ScaledLRU<SampleKey, SampleResult>> scaled_lru;
std::unordered_set<uint64_t> extra_nodes;
std::unordered_map<uint64_t, size_t> extra_nodes_to_thread_index;
13 changes: 13 additions & 0 deletions paddle/fluid/platform/flags.cc
@@ -713,6 +713,19 @@ PADDLE_DEFINE_EXPORTED_bool(
apply_pass_to_program, false,
"It controls whether to apply IR pass to program when using Fleet APIs");

/**
* Distributed related FLAG
* Name: FLAGS_graph_load_in_parallel
* Since Version: 2.2.0
* Value Range: bool, default=false
* Example:
* Note: Controls whether graph nodes and edges are loaded in parallel with multiple threads
* If it is not set, graph data is loaded with a single thread
*/
PADDLE_DEFINE_EXPORTED_bool(
graph_load_in_parallel, false,
"It controls whether load graph node and edge with mutli threads parallely.");

/**
* KP kernel related FLAG
* Name: FLAGS_run_kp_kernel