Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor SbpEdge #8684

Merged
merged 3 commits into from
Jul 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions oneflow/core/auto_parallel/sbp_collector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,13 @@ void SbpCollector::InitializeCopyCostFromNode2Proxy(SbpNode<NdSbpSignature>* sbp
const LogicalBlobId& lbi) {
// the only edge from producer to proxy of producer
SbpEdge<NdSbpSignature>* sbp_edge = sbp_proxy->EdgesIn[0];
SbpNode<NdSbpSignature>* sbp_node_producer = sbp_edge->StartNode;
sbp_edge->Cost.resize(sbp_node_producer->SbpSignatureList.size());
SbpNode<NdSbpSignature>* sbp_node_producer = sbp_edge->start_node_;
sbp_edge->cost_.resize(sbp_node_producer->SbpSignatureList.size());
int32_t consumer_sbp_size = sbp_proxy->ParallelCandidates.size();
// look through sbp signature in producer
for (int32_t sbp_id_producer = 0; sbp_id_producer < sbp_node_producer->SbpSignatureList.size();
sbp_id_producer++) {
sbp_edge->Cost[sbp_id_producer].resize(consumer_sbp_size, 0);
sbp_edge->cost_[sbp_id_producer].resize(consumer_sbp_size, 0);
}

// Assemble copy cost from producer to proxy of producer
Expand Down Expand Up @@ -173,7 +173,7 @@ void SbpCollector::InitializeCopyCostFromNode2Proxy(SbpNode<NdSbpSignature>* sbp

// compute copy cost for a specific logical blob
// Use the producer's parallel description for the consumer as well, for now.
sbp_edge->Cost[sbp_id_producer][sbp_id_consumer] +=
sbp_edge->cost_[sbp_id_producer][sbp_id_consumer] +=
CHECK_JUST(ComputeCopyCostWithMiddleNodes(sbp_producer, sbp_consumer, logical_blob_desc,
producer_parallel_desc,
producer_parallel_desc, /*is_same=*/false));
Expand Down Expand Up @@ -201,15 +201,15 @@ void SbpCollector::InitializeCopyCostFromProxy2Consumer(
// Connect sbp proxy and consumer
sbp_proxy->PointTo(sbp_node_consumer);
// the sbp edge connecting proxy and consumer
SbpEdge<NdSbpSignature>* sbp_edge = FindEdgeBetweenNodes(sbp_proxy, sbp_node_consumer);
sbp_edge->Cost.resize(sbp_proxy->ParallelCandidates.size());
SbpEdge<NdSbpSignature>* sbp_edge = sbp_node_consumer->FindEdgeWithNode(sbp_proxy);
sbp_edge->cost_.resize(sbp_proxy->ParallelCandidates.size());
int32_t consumer_sbp_size = sbp_node_consumer->SbpSignatureList.size();

// look through sbp parallel set in proxy
for (int32_t sbp_id_producer = 0; sbp_id_producer < sbp_proxy->ParallelCandidates.size();
sbp_id_producer++) {
// initialization for copy cost
sbp_edge->Cost[sbp_id_producer].resize(consumer_sbp_size, 0);
sbp_edge->cost_[sbp_id_producer].resize(consumer_sbp_size, 0);
// get sbp parallel set for a logical blob in proxy
BinarySet& parallel_candidate = sbp_proxy->ParallelCandidates[sbp_id_producer];

Expand All @@ -221,7 +221,7 @@ void SbpCollector::InitializeCopyCostFromProxy2Consumer(
const NdSbp& sbp_consumer = consumer_sbp_bn_in_op2sbp_parallel.at(ibn);

if ((!parallel_candidate.CheckExistency(SbpParallelUniverse[sbp_consumer]))) {
sbp_edge->Cost[sbp_id_producer][sbp_id_consumer] = GetMaxVal<float>();
sbp_edge->cost_[sbp_id_producer][sbp_id_consumer] = GetMaxVal<float>();
}
}
}
Expand Down Expand Up @@ -351,15 +351,14 @@ void SbpCollector::ProxySbpCandidate(
// consumer in cost model
SbpNode<NdSbpSignature>* sbp_node_consumer = op_name2sbp_node[consumer_bn_group.first.first];
// the sbp edge connecting producer and consumer
SbpEdge<NdSbpSignature>* edge_found =
FindEdgeBetweenNodes(sbp_node_producer, sbp_node_consumer);
SbpEdge<NdSbpSignature>* edge_found = sbp_node_consumer->FindEdgeWithNode(sbp_node_producer);
// unload logical blob from sbp edges
edge_found->UnloadLbi(lbi);
// Do not clip this edge. Save it for wait time.
// clip this edge if it no longer carries any blob
// We don't clip edges before since we have transfer cost
// Now we clip edges, which makes the topology simpler
if (edge_found->EmptyLbi() && edge_found->WaitTime <= 0.0 && edge_found->WaitTime > -0.5
if (edge_found->EmptyLbi() && edge_found->wait_time_ <= 0.0 && edge_found->wait_time_ > -0.5
&& sbp_graph.transfer_cost <= 0.0) {
sbp_graph.ClipEdge(edge_found);
}
Expand Down
16 changes: 8 additions & 8 deletions oneflow/core/auto_parallel/sbp_constructor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,27 +246,27 @@ Maybe<void> SbpConstructor::InitCopyCost(const OpGraph& op_graph) {
// Initialize copy cost between two nodes
for (auto* sbp_edge : sbp_node_consumer->EdgesIn) {
// producer sbp node
const auto* sbp_node_producer = sbp_edge->StartNode;
const auto* sbp_node_producer = sbp_edge->start_node_;
// skip it if proxy
if (!sbp_node_producer->op_node) { continue; }
sbp_edge->Cost.resize(sbp_node_producer->SbpSignatureList.size());
sbp_edge->cost_.resize(sbp_node_producer->SbpSignatureList.size());
int32_t consumer_sbp_size = sbp_node_consumer->SbpSignatureList.size();
// look through sbp signature in producer
for (int32_t i = 0; i < sbp_node_producer->SbpSignatureList.size(); ++i) {
sbp_edge->Cost[i].resize(consumer_sbp_size, 0);
sbp_edge->cost_[i].resize(consumer_sbp_size, 0);
}
}
// Find all those cases with wait time
// Do not skip edges carrying no lbi
sbp_node_consumer->InitializeCopyCost(false, use_sbp_collector_);
for (auto* sbp_edge : sbp_node_consumer->EdgesIn) {
// skip it if proxy
if (!sbp_edge->StartNode->op_node) { continue; }
if (!sbp_edge->start_node_->op_node) { continue; }
// Reset Wait time
for (int32_t i = 0; i < sbp_edge->Cost.size(); ++i) {
for (int32_t j = 0; j < sbp_edge->Cost[i].size(); ++j) {
for (int32_t i = 0; i < sbp_edge->cost_.size(); ++i) {
for (int32_t j = 0; j < sbp_edge->cost_[i].size(); ++j) {
// If transferring between devices, we need to add wait time.
if (sbp_edge->Cost[i][j] > 0.0) { sbp_edge->Cost[i][j] = sbp_edge->WaitTime; }
if (sbp_edge->cost_[i][j] > 0.0) { sbp_edge->cost_[i][j] = sbp_edge->wait_time_; }
}
}
}
Expand Down Expand Up @@ -301,7 +301,7 @@ void SbpConstructor::LoadLbi2SbpEdge(const OpGraph& op_graph) {
// producer sbp node
const auto* sbp_node_producer = op_name2sbp_node_[producer->op().op_name()];
// TODO: recode this
auto* edge_found = auto_parallel::FindEdgeBetweenNodes(sbp_node_producer, sbp_node_consumer);
auto* edge_found = sbp_node_consumer->FindEdgeWithNode(sbp_node_producer);

CHECK(edge_found != NULL) << "SbpEdge not found while loading!" << std::endl;

Expand Down
Loading