
[CINN] Add data dependency graph #70485

Merged
merged 4 commits · Jan 3, 2025
Refine code
Dmovic committed Dec 31, 2024
commit d96ba6af5cd56b388ee4edcb710f0206dae7ea29
40 changes: 21 additions & 19 deletions paddle/cinn/ir/ir_analyzer/data_dependency_graph.cc
@@ -12,10 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <vector>

#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_analyzer/data_dependency_graph.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"

@@ -52,7 +50,7 @@ class MemRefCollector : public ir::stmt::StmtVisitor<>,

void Visit(const ir::Load* op, const ir::Expr* expr) override {
auto tensor_node = op->tensor.As<ir::_Tensor_>();
loads_.insert({tensor_node->buffer->name, tensor_node->buffer});
loads_.insert({tensor_node->buffer->name});
ir::IRMutator<const ir::Expr*>::Visit(op, expr);
}

@@ -65,18 +63,18 @@
"threadIdx.z"};
if (op->is_symbolic_constant) return;
if (gpu_axis.count(op->name)) return;
loads_.insert({op->name, op->Copy()});
loads_.insert({op->name});
ir::IRMutator<const ir::Expr*>::Visit(op, expr);
}

void Visit(const ir::Call* op, const ir::Expr* expr) override {
for (auto write_arg : op->write_args) {
if (write_arg.As<ir::_Var_>()) {
stores_.insert({write_arg.As<ir::_Var_>()->name, write_arg});
stores_.insert({write_arg.As<ir::_Var_>()->name});
} else if (write_arg.As<ir::Load>()) {
auto load_node = write_arg.As<ir::Load>();
auto tensor_node = load_node->tensor.As<ir::_Tensor_>();
stores_.insert({tensor_node->buffer->name, tensor_node->buffer});
stores_.insert({tensor_node->buffer->name});
} else {
VLOG(6) << "Not support type in write arguments: \n" << write_arg;
}
@@ -90,14 +88,14 @@ class MemRefCollector : public ir::stmt::StmtVisitor<>,

void MemRefCollector::VisitStmt(const ir::stmt::Let& stmt) {
if (stmt->symbol().As<ir::_Var_>())
stores_.insert({stmt->symbol().As<ir::_Var_>()->name, stmt->symbol()});
stores_.insert({stmt->symbol().As<ir::_Var_>()->name});
ir::IRMutator<const ir::Expr*>::Visit(&stmt->body(), &stmt->body());
}

void MemRefCollector::VisitStmt(const ir::stmt::Store& stmt) {
auto tensor_node = stmt->tensor().As<ir::_Tensor_>();
if (tensor_node->buffer.get()) {
stores_.insert({tensor_node->buffer->name, tensor_node->buffer});
stores_.insert({tensor_node->buffer->name});
}
ir::IRMutator<const ir::Expr*>::Visit(&stmt->value(), &stmt->value());
for (std::size_t i = 0; i < stmt->indices().size(); i++) {
@@ -130,6 +128,14 @@ void MemRefCollector::VisitStmt(const ir::stmt::Evaluate& stmt) {
void MemRefCollector::VisitStmt(const ir::stmt::Alloc& stmt) {}
void MemRefCollector::VisitStmt(const ir::stmt::Free& stmt) {}

DataDependencyGraph::Node::Node(unsigned id, const ir::stmt::StmtRef& stmt)
: id(id), stmt(stmt) {
MemRefCollector collector;
collector.VisitStmt(stmt);
loads = collector.GetLoads();
stores = collector.GetStores();
}

DepKind DataDependencyGraph::HasDependency(const ir::stmt::StmtRef& src,
const ir::stmt::StmtRef& dst) const {
// Run BFS traversal to check if src and dst are reachable.
@@ -141,8 +147,8 @@ DepKind DataDependencyGraph::HasDependency(const ir::stmt::StmtRef& src,
while (!queue.empty()) {
auto id = queue.front();
if (id == dst_id) return DepKind::DEP;
// If node has no out edges, or have been visited already, record node and
// continue.
// If node has no out edges, or have been visited already,
// record node and continue.
if (out_edges_.count(id) == 0 || visited.count(id) != 0) {
queue.pop_front();
visited.insert(id);
@@ -175,10 +181,6 @@ void DataDependencyGraph::BuildGraphByStmts() {
auto BuildNodes = [&]() {
for (auto& stmt : stmts_) {
Node node(next_node_id_++, stmt);
MemRefCollector collector;
collector.VisitStmt(stmt);
node.loads = collector.GetLoads();
node.stores = collector.GetStores();
stmt_to_node_ids_.insert({stmt, node.id});
nodes_.insert({node.id, node});
}
@@ -230,10 +232,10 @@ void DataDependencyGraph::Print(int log_level) const {
auto stmt = node.stmt;
VLOG(log_level) << "Node" << id << " stmt: " << stmt;
for (auto load : node.loads) {
VLOG(log_level) << "Load: " << load.data;
VLOG(log_level) << "Load: " << load.name;
}
for (auto store : node.stores) {
VLOG(log_level) << "Store: " << store.data;
VLOG(log_level) << "Store: " << store.name;
}
auto it = in_edges_.find(id);
if (it != in_edges_.end()) {
@@ -243,7 +245,7 @@
VLOG(log_level) << "In Edge: \n"
<< nodes_.at(e.id).stmt << " dst: \n"
<< nodes_.at(id).stmt << " DepData:\n"
<< value.first.data << " DepKind: " << dep_kind;
<< value.first.name << " DepKind: " << dep_kind;
}
it = out_edges_.find(id);
if (it != out_edges_.end()) {
@@ -253,7 +255,7 @@
VLOG(log_level) << "Out Edge: \n"
<< nodes_.at(id).stmt << " dst: \n"
<< nodes_.at(e.id).stmt << " DepData:\n"
<< value.first.data << " DepKind: " << dep_kind;
<< value.first.name << " DepKind: " << dep_kind;
}
}
}
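
For reference, the reachability check in HasDependency above is a plain breadth-first search over the out-edge map. Below is a minimal standalone sketch of that loop with simplified stand-in types; Edge, AdjacencyMap, and Reachable are illustrative names for this sketch, not the class's actual members.

```cpp
#include <deque>
#include <map>
#include <set>
#include <vector>

// Illustrative stand-ins for the graph's edge list and adjacency map.
struct Edge {
  unsigned id;  // id of the node this edge points to
};
using AdjacencyMap = std::map<unsigned, std::vector<Edge>>;

// Returns true if dst_id is reachable from src_id by following out edges,
// mirroring the BFS loop in DataDependencyGraph::HasDependency.
bool Reachable(const AdjacencyMap& out_edges, unsigned src_id, unsigned dst_id) {
  std::deque<unsigned> queue = {src_id};
  std::set<unsigned> visited;
  while (!queue.empty()) {
    unsigned id = queue.front();
    queue.pop_front();
    if (id == dst_id) return true;
    if (visited.count(id) != 0) continue;
    visited.insert(id);
    auto it = out_edges.find(id);
    if (it == out_edges.end()) continue;
    for (const Edge& e : it->second) queue.push_back(e.id);
  }
  return false;
}
```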
14 changes: 5 additions & 9 deletions paddle/cinn/ir/ir_analyzer/data_dependency_graph.h
@@ -30,10 +30,9 @@ namespace analyzer {
// Dep detail: RAW | WAW | WAR
enum class DepKind { DEP, NO_DEP };

// Var or Tensor
// Var or Tensor name
struct DepData {
std::string name;
ir::Expr data;
};

struct StmtCompare {
@@ -45,10 +44,7 @@

struct DepDataCompare {
bool operator()(const DepData& a, const DepData& b) const {
if (!a.data.defined() || !b.data.defined()) {
return a.name < b.name;
}
return !(a.data.get() == b.data.get() || a.name == b.name);
return a.name < b.name;
}
};

@@ -71,14 +67,14 @@ class DataDependencyGraph {
BuildGraphByStmts();
}

// Returns DepKind::Dep if there is a path in the data dependency graph from
// Returns DepKind::DEP if there is a path in the data dependency graph from
// node src to node dst. Returns DepKind::NO_DEP otherwise. src and dst, are
// expected to be from the same block.
DepKind HasDependency(const ir::stmt::StmtRef& src,
const ir::stmt::StmtRef& dst) const;

// Node represents a node in the graph. A Node is a stmt which contains
// loads/stores,
// loads/stores.
struct Node {
// The unique identifier of this node in the graph.
unsigned id;
@@ -90,7 +86,7 @@
std::set<DepData, DepDataCompare> stores;

Node() = default;
Node(unsigned id, const ir::stmt::StmtRef& stmt) : id(id), stmt(stmt) {}
Node(unsigned id, const ir::stmt::StmtRef& stmt);
};

struct Edge {
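
A hedged usage sketch of the graph after this refinement: the constructor taking the block's statements is assumed from the BuildGraphByStmts() call above, and the cinn::ir::analyzer namespace is assumed from the file path; neither is shown verbatim in this diff.

```cpp
#include <vector>

#include "paddle/cinn/ir/ir_analyzer/data_dependency_graph.h"

// Assumed usage: `stmts` are statements from the same block, and the graph is
// built directly from them. DepData now carries only a name, so two statements
// are connected whenever they touch a buffer or variable with the same name,
// directly or through intermediate statements.
bool StmtsHaveDependency(const std::vector<cinn::ir::stmt::StmtRef>& stmts) {
  cinn::ir::analyzer::DataDependencyGraph graph(stmts);
  return graph.HasDependency(stmts.front(), stmts.back()) ==
         cinn::ir::analyzer::DepKind::DEP;
}
```

Since DepData no longer stores an ir::Expr, DepDataCompare orders entries lexicographically by name, so repeated references to the same buffer or variable collapse to a single entry in each node's load/store set.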