Skip to content

Commit 2a0c2f8

Browse files
authored
Add walkthrough test on python and debug (PaddlePaddle#204)
1 parent 4dd175d commit 2a0c2f8

File tree

13 files changed

+212
-22
lines changed

13 files changed

+212
-22
lines changed

cinn/common/graph_utils.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,17 @@ class GraphNode;
3131
*/
3232
class GraphEdge : public Object {
3333
public:
34-
GraphEdge(GraphNode* source, GraphNode* sink) : source_(source), sink_(sink) {}
34+
GraphEdge(GraphNode* source, GraphNode* sink, int index = -1) : source_(source), sink_(sink), index_(index) {}
3535

3636
GraphNode* source() const { return source_; }
3737
GraphNode* sink() const { return sink_; }
3838
const char* type_info() const override { return __type_info__; }
39+
int index() const { return index_; }
3940

4041
private:
42+
//! the index in sink node's inlinks_ or source node's outlinks_
43+
//! this is used to keep the input/output tensor's order of operator node
44+
int index_{-1};
4145
//! Source of this edge.
4246
GraphNode* source_{};
4347
//! End of this edge.
@@ -64,9 +68,10 @@ class GraphNode : public Object {
6468
EdgeT *a, *b;
6569
CHECK(other);
6670
CHECK_NE(other, this) << "cannot link to itself";
67-
auto edge = make_shared<GraphEdge>(this, other);
68-
auto edge1 = make_shared<GraphEdge>(this, other);
69-
71+
auto edge = make_shared<GraphEdge>(this, other, index_outlinks);
72+
auto edge1 = make_shared<GraphEdge>(this, other, other->index_inlinks);
73+
index_outlinks++;
74+
other->index_inlinks++;
7075
outlinks_.insert(edge);
7176
other->inlinks_.insert(edge1);
7277

@@ -140,6 +145,9 @@ class GraphNode : public Object {
140145
std::set<common::Shared<GraphEdge>, GraphEdgeCompare> outlinks_;
141146

142147
mutable int visited_time_{};
148+
//! used to mark the index of node's input/output tensors
149+
int index_inlinks{0};
150+
int index_outlinks{0};
143151
};
144152

145153
/**

cinn/frontend/syntax.cc

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "cinn/frontend/syntax.h"
2+
#include "cinn/hlir/framework/node.h"
23
#include "cinn/hlir/framework/op.h"
34
#include "cinn/utils/string.h"
45

@@ -34,6 +35,38 @@ Variable Program::add(const Variable& a, const Variable& b) {
3435
return instr.GetOutputs()[0];
3536
}
3637

38+
Variable Program::relu(const Variable& a) {
39+
Instruction instr("relu");
40+
instr.SetInputs({a});
41+
AddInstruction(instr);
42+
return instr.GetOutputs()[0];
43+
}
44+
45+
std::vector<Variable> Program::conv2d(
46+
const Variable& a,
47+
const Variable& b,
48+
const std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t>& attr_store) {
49+
Instruction instr("conv2d");
50+
instr.SetInputs({a, b});
51+
for (auto& iter : attr_store) {
52+
instr.SetAttr(iter.first, iter.second);
53+
}
54+
AddInstruction(instr);
55+
return instr.GetOutputs();
56+
}
57+
58+
Variable Program::batchnorm(const Variable& a,
59+
const Variable& b,
60+
const std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t>& attr_store) {
61+
Instruction instr("batchnorm");
62+
instr.SetInputs({a, b});
63+
for (auto& iter : attr_store) {
64+
instr.SetAttr(iter.first, iter.second);
65+
}
66+
AddInstruction(instr);
67+
return instr.GetOutputs()[0];
68+
}
69+
3770
Instruction& Program::operator[](size_t i) {
3871
CHECK_LT(i, instrs.size());
3972
return instrs[i];

cinn/frontend/syntax.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,41 @@ struct Program {
145145
*/
146146
Variable add(const Variable& a, const Variable& b);
147147

148+
/**
149+
* Apply Rectified Linear Unit on input Variable.
150+
* Actually apply: output = max(input, 0)
151+
*
152+
* @param a The first variable.
153+
* @return The result.
154+
*/
155+
Variable relu(const Variable& a);
156+
157+
/**
158+
* The convolution2D layer calculates the output based on the input, filter
159+
* and strides, paddings, dilations, groups parameters.
160+
*
161+
* @param a The first variable input.
162+
* @param b The second variable filter(weights).
163+
* @param attr_store The params like padding, stride, dilation, etc.
164+
* @return The result.
165+
*/
166+
std::vector<Variable> conv2d(const Variable& a,
167+
const Variable& b,
168+
const std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t>& attr_store);
169+
170+
/**
171+
* The batchnorm layer can be used as a normalizer function
172+
* for convolution or fully_connected operations.
173+
*
174+
* @param a The first variable input.
175+
* @param b The second variable filter(weights).
176+
* @param attr_store The params like epsilon.
177+
* @return The result.
178+
*/
179+
Variable batchnorm(const Variable& a,
180+
const Variable& b,
181+
const std::unordered_map<std::string, hlir::framework::NodeAttr::attr_t>& attr_store);
182+
148183
/**
149184
* Get \p i-th instruction.
150185
*/

cinn/hlir/framework/graph_compiler.cc

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,16 @@ namespace cinn {
99
namespace hlir {
1010
namespace framework {
1111

12+
void GraphCompiler::PrintFunc() {
13+
auto [nodes, edges] = graph_->topological_order();
14+
for (auto& n : nodes) {
15+
auto* node = n->safe_as<Node>();
16+
if (node) {
17+
auto lowered_func = GetOpFunc(node);
18+
}
19+
}
20+
}
21+
1222
std::unique_ptr<Program> GraphCompiler::Build() {
1323
auto [nodes, edges] = graph_->topological_order();
1424
for (auto& n : nodes) {
@@ -52,7 +62,7 @@ ir::LoweredFunc GraphCompiler::GetOpFunc(const Node* node) {
5262
auto& dtype_dict = graph_->GetAttrs<std::unordered_map<std::string, Type>>("inferdtype");
5363
std::vector<ir::Tensor> inputs;
5464
std::vector<common::CINNValue> cinn_inputs;
55-
for (auto& i : node->inlinks()) {
65+
for (auto& i : node->inlinks_in_order()) {
5666
std::string input_id = i->source()->as<NodeData>()->id();
5767
std::vector<int> in_shape = shape_dict.at(input_id);
5868
Type dtype = dtype_dict.at(input_id);
@@ -63,7 +73,7 @@ ir::LoweredFunc GraphCompiler::GetOpFunc(const Node* node) {
6373
cinn_inputs.push_back(common::CINNValue(temp));
6474
}
6575
std::vector<Type> out_types;
66-
for (auto& out : node->outlinks()) {
76+
for (auto& out : node->outlinks_in_order()) {
6777
std::string out_id = out->sink()->safe_as<NodeData>()->id();
6878
Type dtype = dtype_dict.at(out_id);
6979
out_types.push_back(dtype);
@@ -80,21 +90,21 @@ ir::LoweredFunc GraphCompiler::GetOpFunc(const Node* node) {
8090
}
8191

8292
auto func = Lower(GenOpFuncName(node), stages, inputs);
83-
93+
LOG(INFO) << "The function of node [" << node->attrs.node_name << "] is: " << func;
8494
return func;
8595
}
8696

8797
std::vector<std::string> GraphCompiler::OpGetInputNames(const Node* node) const {
8898
std::vector<std::string> res;
89-
for (auto& i : node->inlinks()) {
99+
for (auto& i : node->inlinks_in_order()) {
90100
res.push_back(i->source()->as<NodeData>()->id());
91101
}
92102
return res;
93103
}
94104

95105
std::vector<std::string> GraphCompiler::OpGetOutputNames(const Node* node) const {
96106
std::vector<std::string> res;
97-
for (auto& i : node->outlinks()) {
107+
for (auto& i : node->outlinks_in_order()) {
98108
res.push_back(i->sink()->as<NodeData>()->id());
99109
}
100110
return res;

cinn/hlir/framework/graph_compiler.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ class GraphCompiler final {
5959

6060
std::unique_ptr<Program> Build();
6161

62+
void PrintFunc();
63+
6264
private:
6365
ir::LoweredFunc GetOpFunc(const Node* node);
6466

cinn/hlir/framework/node.cc

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
#include "cinn/hlir/framework/node.h"
2+
#include <algorithm>
23

34
namespace cinn {
45
namespace hlir {
56
namespace framework {
67

7-
std::tuple<common::GraphEdge *, common::GraphEdge *> Node::LinkTo(NodeData *other) {
8+
std::tuple<common::GraphEdge*, common::GraphEdge*> Node::LinkTo(NodeData* other) {
89
return this->common::GraphNode::LinkTo(other->as<common::GraphNode>());
910
}
1011

11-
std::tuple<common::GraphEdge *, common::GraphEdge *> NodeData::LinkTo(Node *other) {
12+
std::tuple<common::GraphEdge*, common::GraphEdge*> NodeData::LinkTo(Node* other) {
1213
return this->common::GraphNode::LinkTo(other->as<common::GraphNode>());
1314
}
1415

@@ -49,6 +50,35 @@ std::ostream &operator<<(std::ostream &os, const NodeAttr &node_attr) {
4950
return os;
5051
}
5152

53+
//! Using index to sort the input/output tensors
54+
bool edge_index_compare(const common::Shared<common::GraphEdge>& a, const common::Shared<common::GraphEdge>& b) {
55+
return a->index() < b->index();
56+
}
57+
58+
const std::vector<common::Shared<common::GraphEdge>>& Node::inlinks_in_order() const {
59+
if (inlinks_in_order_.empty()) {
60+
for (auto& in_edge : this->inlinks()) {
61+
inlinks_in_order_.push_back(in_edge);
62+
CHECK_GE(in_edge->index(), 0) << "The index of a node's inlinks should be >= 0! Now index is: "
63+
<< in_edge->index() << ". Please check.";
64+
}
65+
std::sort(inlinks_in_order_.begin(), inlinks_in_order_.end(), edge_index_compare);
66+
}
67+
return inlinks_in_order_;
68+
}
69+
70+
const std::vector<common::Shared<common::GraphEdge>>& Node::outlinks_in_order() const {
71+
if (outlinks_in_order_.empty()) {
72+
for (auto& out_edge : this->outlinks()) {
73+
outlinks_in_order_.push_back(out_edge);
74+
CHECK_GE(out_edge->index(), 0) << "The index of a node's outlinks should be >= 0! Now index is: "
75+
<< out_edge->index() << ". Please check.";
76+
}
77+
std::sort(outlinks_in_order_.begin(), outlinks_in_order_.end(), edge_index_compare);
78+
}
79+
return outlinks_in_order_;
80+
}
81+
5282
} // namespace framework
5383
} // namespace hlir
5484
} // namespace cinn

cinn/hlir/framework/node.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <vector>
99

1010
#include "cinn/common/graph_utils.h"
11+
#include "cinn/common/shared.h"
1112
#include "cinn/hlir/framework/op.h"
1213

1314
namespace cinn {
@@ -75,6 +76,12 @@ class Node : public common::GraphNode {
7576
*/
7677
NodeAttr attrs;
7778

79+
//! Get the input tensors in order to match tensors correctly.
80+
const std::vector<common::Shared<common::GraphEdge>> &inlinks_in_order() const;
81+
82+
//! Get the output tensors in order to match tensors correctly.
83+
const std::vector<common::Shared<common::GraphEdge>> &outlinks_in_order() const;
84+
7885
inline const Operator *op() const { return this->attrs.op; }
7986

8087
inline bool is_variable() { return (this->attrs.op == nullptr); }
@@ -95,6 +102,8 @@ class Node : public common::GraphNode {
95102
* \brief The unique id of the node.
96103
*/
97104
std::string id_;
105+
mutable std::vector<common::Shared<common::GraphEdge>> outlinks_in_order_{};
106+
mutable std::vector<common::Shared<common::GraphEdge>> inlinks_in_order_{};
98107
};
99108

100109
/**

cinn/hlir/framework/pass.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "cinn/hlir/framework/pass.h"
2+
#include "cinn/hlir/pass/use_pass.h"
23

34
namespace cinn {
45
namespace hlir {

cinn/hlir/op/nn.cc

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,13 +169,21 @@ std::vector<std::vector<int>> InferShapeForConv2d(const std::vector<std::vector<
169169
CHECK_EQ(inputs_shape[0].size(), 4) << "The first input tensor's shape size of conv2d op is not 4! Please check.";
170170
int out_shape_h = (inputs_shape[0][2] - ((inputs_shape[1][2] - 1) * dilation + 1) + 2 * padding[0]) / stride[0] + 1;
171171
int out_shape_w = (inputs_shape[0][3] - ((inputs_shape[1][3] - 1) * dilation + 1) + 2 * padding[1]) / stride[1] + 1;
172-
std::vector<std::vector<int>> res{{inputs_shape[0][0], inputs_shape[1][0], out_shape_h, out_shape_w}};
172+
std::vector<std::vector<int>> res{{inputs_shape[0][0],
173+
inputs_shape[0][1],
174+
inputs_shape[0][2] + 2 * padding[0],
175+
inputs_shape[0][3] + 2 * padding[1]},
176+
{inputs_shape[1][0],
177+
inputs_shape[1][1],
178+
(inputs_shape[1][2] - 1) * dilation + 1,
179+
(inputs_shape[1][3] - 1) * dilation + 1},
180+
{inputs_shape[0][0], inputs_shape[1][0], out_shape_h, out_shape_w}};
173181
return res;
174182
}
175183

176184
std::vector<Type> InferDtypeForConv2d(const std::vector<Type> &inputs_type, const framework::NodeAttr &attrs) {
177185
CHECK(!inputs_type.empty()) << "The input's type size is 0! Please check again.";
178-
std::vector<Type> res{inputs_type[0]};
186+
std::vector<Type> res{inputs_type[0], inputs_type[1], inputs_type[0]};
179187
return res;
180188
}
181189

@@ -254,7 +262,7 @@ CINN_REGISTER_HELPER(nn_ops) {
254262
CINN_REGISTER_OP(conv2d)
255263
.describe("Do a 2-D convolution with an NCHW-layout.")
256264
.set_num_inputs(2) // here we consider filter as another input
257-
.set_num_outputs(1)
265+
.set_num_outputs(3)
258266
.set_attr<cinn::hlir::framework::StrategyFunction>("CINNStrategy", cinn::hlir::op::StrategyForConv2d)
259267
.set_attr("infershape", std::function(cinn::hlir::op::InferShapeForConv2d))
260268
.set_attr("inferdtype", std::function(cinn::hlir::op::InferDtypeForConv2d))

cinn/pybind/frontend.cc

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,35 @@
1+
#include <pybind11/functional.h>
2+
#include <pybind11/operators.h>
13
#include <pybind11/pybind11.h>
4+
#include <pybind11/stl.h>
5+
#include "cinn/common/type.h"
26
#include "cinn/frontend/syntax.h"
7+
#include "cinn/hlir/framework/graph.h"
8+
#include "cinn/hlir/framework/graph_compiler.h"
9+
#include "cinn/hlir/framework/pass.h"
310
#include "cinn/hlir/op/use_ops.h"
411
#include "cinn/utils/string.h"
512

613
namespace cinn::pybind {
7-
14+
using common::Type;
15+
using frontend::Placeholder;
816
namespace py = pybind11;
917
using namespace cinn::frontend; // NOLINT
1018

1119
void BindFrontend(pybind11::module *m) {
1220
py::class_<Variable>(*m, "Variable") //
1321
.def(py::init<const std::string &>(), py::arg("id") = "")
1422
.def("__str__", [](Variable &self) { return self->id; })
15-
.def("__repr__", [](Variable &self) { return utils::GetStreamCnt(self); });
23+
.def("__repr__", [](Variable &self) { return utils::GetStreamCnt(self); })
24+
.def("set_type",
25+
[](Variable &self, const Type &type) {
26+
self->type = type;
27+
return self;
28+
})
29+
.def("set_shape", [](Variable &self, const std::vector<int> &shape) {
30+
self->shape = shape;
31+
return self;
32+
});
1633

1734
py::class_<Placeholder>(*m, "Placeholder") //
1835
.def(py::init<const common::Type &, const std::vector<int> &, std::string_view>(),
@@ -45,7 +62,17 @@ void BindFrontend(pybind11::module *m) {
4562
.def(py::init<>())
4663
.def("size", &Program::size)
4764
.def("__getitem__", [](Program &self, int idx) { return self[idx]; })
48-
.def("add", &Program::add);
65+
.def("add", &Program::add)
66+
.def("relu", &Program::relu)
67+
.def("conv2d", &Program::conv2d)
68+
.def("batchnorm", &Program::batchnorm)
69+
.def("print_func", [](Program &self, const common::Target &target) {
70+
std::shared_ptr<hlir::framework::Graph> g(new hlir::framework::Graph(self));
71+
hlir::framework::ApplyPass(g.get(), "InferShape");
72+
std::shared_ptr<hlir::framework::Scope> scope = hlir::framework::BuildScope(target, g);
73+
hlir::framework::GraphCompiler gc(target, scope, g);
74+
gc.PrintFunc();
75+
});
4976
} // namespace frontend
5077

5178
} // namespace cinn::pybind

0 commit comments

Comments
 (0)