Skip to content

Commit 54e46e3

Browse files
committed
Clean Debug Code on Previous PRs (#59839)
Clean code of #59209 and #59014
1 parent 142c71a commit 54e46e3

File tree

2 files changed

+4
-254
lines changed

2 files changed

+4
-254
lines changed

paddle/cinn/optim/transform_gpu_forloop.cc

Lines changed: 0 additions & 158 deletions
Original file line numberDiff line numberDiff line change
@@ -229,163 +229,6 @@ class ReplaceIndexToBindExpr : public ir::IRMutator<> {
229229
}
230230
};
231231

232-
using TENSOR_LOOP = std::pair<ir::Expr, std::vector<ir::Expr>>;
233-
class CollectTensorLoopVisitor : public ir::IRMutator<> {
234-
public:
235-
void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
236-
237-
private:
238-
void Visit(const ir::Store *op, Expr *expr) override {
239-
auto tensor = op->tensor.as_tensor_ref();
240-
// if buffer defined and buffer is not Heap.
241-
if (tensor->buffer.defined() &&
242-
tensor->buffer->memory_type != ir::MemoryType::Heap) {
243-
if (buffer_tensor_loop_map_.count(tensor->buffer->name)) {
244-
buffer_tensor_loop_map_[tensor->buffer->name].push_back(
245-
std::make_pair(*expr, loops_));
246-
} else {
247-
buffer_tensor_loop_map_[tensor->buffer->name] = {
248-
std::make_pair(*expr, loops_)};
249-
}
250-
}
251-
252-
IRMutator::Visit(op, expr);
253-
}
254-
255-
void Visit(const ir::Load *op, Expr *expr) override {
256-
if (op->is_addr_scalar()) {
257-
return;
258-
}
259-
auto tensor = op->tensor.as_tensor_ref();
260-
// if buffer defined and buffer is not Heap.
261-
if (tensor->buffer.defined() &&
262-
tensor->buffer->memory_type != ir::MemoryType::Heap) {
263-
if (buffer_tensor_loop_map_.count(tensor->buffer->name)) {
264-
buffer_tensor_loop_map_[tensor->buffer->name].push_back(
265-
std::make_pair(*expr, loops_));
266-
} else {
267-
buffer_tensor_loop_map_[tensor->buffer->name] = {
268-
std::make_pair(*expr, loops_)};
269-
}
270-
}
271-
272-
IRMutator::Visit(op, expr);
273-
}
274-
275-
void Visit(const ir::For *op, Expr *expr) override {
276-
loops_.push_back(*expr);
277-
IRMutator::Visit(op, expr);
278-
loops_.pop_back();
279-
}
280-
281-
void Visit(const ir::PolyFor *op, Expr *expr) override {
282-
LOG(FATAL) << "Unkown PolyFor!";
283-
}
284-
285-
public:
286-
std::vector<ir::Expr> loops_;
287-
std::unordered_map<std::string, std::vector<TENSOR_LOOP>>
288-
buffer_tensor_loop_map_;
289-
};
290-
291-
void UpdateBufferAxisPassOld(ir::Expr *expr) {
292-
CollectTensorLoopVisitor collect_tensor_loop_visitor;
293-
collect_tensor_loop_visitor(expr);
294-
295-
auto buffer_tensor_loop = collect_tensor_loop_visitor.buffer_tensor_loop_map_;
296-
297-
for (auto &tmp : buffer_tensor_loop) {
298-
auto tensor_loop_v = tmp.second;
299-
300-
auto &front = tensor_loop_v.front();
301-
int count = tensor_loop_v.size() > 1 ? front.second.size() : 0;
302-
for (int idx = 1; idx < tensor_loop_v.size(); ++idx) {
303-
auto &other = tensor_loop_v[idx];
304-
for (int idy = 0;
305-
idy < std::min(front.second.size(), other.second.size());
306-
++idy) {
307-
if (front.second[idy] != other.second[idy]) {
308-
count = std::min(count, idy);
309-
break;
310-
}
311-
}
312-
}
313-
314-
auto get_thread_bind_var = [](const std::vector<ir::Expr> &loops) {
315-
// threadidx loop_var,extent.
316-
using ThreadLoopVarExtentMap =
317-
std::unordered_map<std::string, std::pair<std::string, int>>;
318-
ThreadLoopVarExtentMap thread_loop_var_exent_map;
319-
for (auto loop : loops) {
320-
auto loop_ir = loop.As<ir::For>();
321-
CHECK(loop_ir);
322-
if (loop_ir->is_gpu_thread_binded()) {
323-
std::string axis = "";
324-
if (loop_ir->bind_info().offset == 0) {
325-
axis = "threadIdx.x";
326-
} else if (loop_ir->bind_info().offset == 1) {
327-
axis = "threadIdx.y";
328-
} else {
329-
axis = "threadIdx.z";
330-
}
331-
// insert gpu thread loop var.
332-
if (thread_loop_var_exent_map.count(axis)) {
333-
auto &loop_var_extent = thread_loop_var_exent_map[axis];
334-
if (loop_var_extent.second >= loop_ir->extent.as_int32()) {
335-
thread_loop_var_exent_map[axis] = std::make_pair(
336-
loop_ir->loop_var->name, loop_ir->extent.as_int32());
337-
}
338-
} else {
339-
thread_loop_var_exent_map[axis] = std::make_pair(
340-
loop_ir->loop_var->name, loop_ir->extent.as_int32());
341-
}
342-
}
343-
}
344-
345-
std::unordered_set<std::string> loop_var_map;
346-
for (auto &tmp : thread_loop_var_exent_map) {
347-
loop_var_map.insert(tmp.second.first);
348-
}
349-
350-
return loop_var_map;
351-
};
352-
353-
auto load = front.first.As<ir::Load>();
354-
auto store = front.first.As<ir::Store>();
355-
auto tensor =
356-
load ? load->tensor.as_tensor_ref() : store->tensor.as_tensor_ref();
357-
// find store and load keep loop for shared
358-
std::vector<std::unordered_set<std::string>> keep_loop_vars;
359-
if (tensor->buffer->memory_type == ir::MemoryType::GPUShared) {
360-
for (auto &tensor_loop : tensor_loop_v) {
361-
keep_loop_vars.push_back(get_thread_bind_var(tensor_loop.second));
362-
}
363-
CHECK_EQ(keep_loop_vars.size(), tensor_loop_v.size());
364-
}
365-
366-
auto &loops = front.second;
367-
for (int idx = 0; idx < count; ++idx) {
368-
auto loop_expr = loops[idx];
369-
auto loop_ir = loop_expr.As<ir::For>();
370-
auto loop_var = loop_ir->loop_var;
371-
372-
for (int idy = 0; idy < tensor_loop_v.size(); ++idy) {
373-
auto expr = tensor_loop_v[idy].first;
374-
auto load = expr.As<ir::Load>();
375-
auto store = expr.As<ir::Store>();
376-
if (keep_loop_vars.size() == 0 ||
377-
!keep_loop_vars[idy].count(loop_var->name)) {
378-
auto &indices = load ? load->indices : store->indices;
379-
for (auto &indice : indices) {
380-
optim::ReplaceVarWithExpr(&indice, loop_var, ir::Expr(0));
381-
indice = cinn::common::AutoSimplify(indice);
382-
}
383-
}
384-
}
385-
}
386-
}
387-
}
388-
389232
class ReplaceLoopVarToGpu : public ir::IRMutator<> {
390233
public:
391234
void operator()(Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
@@ -586,7 +429,6 @@ void OptimizeExprGPU(Expr *expr) {
586429

587430
// resize buffer axis
588431
UpdateBufferAxisPass(expr);
589-
// UpdateBufferAxisPassOld(expr);
590432

591433
// replace var name with block/thread
592434
ReplaceLoopVarToGpu replace_loop_var_to_gpu;

paddle/cinn/optim/update_buffer_axis_pass.cc

Lines changed: 4 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -61,80 +61,6 @@ void FormalizeSingleIndex(const ir::Tensor& tensor,
6161
}
6262
}
6363

64-
/**
65-
* This is a template pass to update the buffer access when using
66-
* single axis of a mult-dim tensor. For example, if the tensor t
67-
* t.shape = [2, 3, 4] and the buffer access is t[12 * k]
68-
* it is same as t[k, 0, 0]. It is easy for human to understand
69-
* they are the same but not easy for compiler.
70-
*
71-
* This class check the buffer access are the same and update those
72-
* same buffer access with the same index expr.
73-
*
74-
* Note! this is a temporary solution. Our symbolic simplify is not
75-
* powerful to simplify the 12 * k / 4 % 3 and so on. So we only handle
76-
* the simplest case. We can modify our class when we can simplify the
77-
* 12 * k / 4 % 3 well.
78-
*/
79-
class AnalyzeSingleAxisOfMultDimTensor : public ir::IRMutator<> {
80-
public:
81-
void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
82-
83-
void Visit(const ir::Store* op, Expr* expr) override {
84-
ir::Store* store = expr->As<ir::Store>();
85-
ir::Tensor tensor = store->tensor.as_tensor_ref();
86-
AnalyzeSingleAxisAccess(store->indices, tensor);
87-
ir::IRMutator<>::Visit(op, expr);
88-
}
89-
90-
// Analyze the buffer access inside load
91-
void Visit(const ir::Load* op, Expr* expr) override {
92-
ir::Load* load = expr->As<ir::Load>();
93-
ir::Tensor tensor = load->tensor.as_tensor_ref();
94-
AnalyzeSingleAxisAccess(load->indices, tensor);
95-
ir::IRMutator<>::Visit(op, expr);
96-
}
97-
98-
void AnalyzeSingleAxisAccess(const std::vector<Expr>& indices,
99-
const ir::Tensor& tensor) {
100-
if (!tensor->buffer.defined() ||
101-
tensor->buffer->memory_type == ir::MemoryType::Heap ||
102-
tensor->buffer->memory_type == ir::MemoryType::GPUShared) {
103-
return;
104-
}
105-
CHECK(indices.size() > 0) << "Buffer access indices is empty";
106-
const std::string& buffer_name = tensor->buffer->name;
107-
const std::vector<ir::Expr>& shape = tensor->shape;
108-
109-
ir::Expr index_expr;
110-
if (indices.size() == 1 && shape.size() > 1) {
111-
index_expr = indices[0];
112-
} else if (indices.size() == shape.size()) {
113-
ir::Expr mul = Expr(1);
114-
index_expr = indices.back();
115-
for (int i = static_cast<int>(indices.size()) - 2; i >= 0; --i) {
116-
mul = ir::Mul::Make(shape[i + 1], mul);
117-
ir::Expr cur = ir::Mul::Make(indices[i], mul);
118-
index_expr = ir::Add::Make(cur, index_expr);
119-
}
120-
}
121-
index_expr = common::AutoSimplify(index_expr);
122-
123-
if (!buffer_name_to_same_single_axis.count(buffer_name)) {
124-
buffer_name_to_same_single_axis[buffer_name] = index_expr;
125-
return;
126-
} else {
127-
const ir::Expr& stored_index_expr =
128-
buffer_name_to_same_single_axis[buffer_name];
129-
if (!ExprMathEqual(index_expr, stored_index_expr)) {
130-
buffer_name_to_same_single_axis.erase(buffer_name);
131-
}
132-
}
133-
}
134-
135-
std::unordered_map<std::string, ir::Expr> buffer_name_to_same_single_axis;
136-
};
137-
13864
class AnalyzeBufferAxis : public ir::IRMutator<> {
13965
public:
14066
void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
@@ -260,11 +186,9 @@ class ReplaceSameAxisToZero : public ir::IRMutator<> {
260186
public:
261187
ReplaceSameAxisToZero(
262188
const std::unordered_map<std::string, std::map<int, ir::Expr>>&
263-
buffer_name_access_same_index_expr,
264-
const std::unordered_map<std::string, ir::Expr>&
265-
buffer_name_to_same_single_axis)
266-
: buffer_name_access_same_index_expr_(buffer_name_access_same_index_expr),
267-
buffer_name_to_same_single_axis_(buffer_name_to_same_single_axis) {}
189+
buffer_name_access_same_index_expr)
190+
: buffer_name_access_same_index_expr_(
191+
buffer_name_access_same_index_expr) {}
268192

269193
void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
270194

@@ -303,29 +227,15 @@ class ReplaceSameAxisToZero : public ir::IRMutator<> {
303227
}
304228
return;
305229
}
306-
if (buffer_name_to_same_single_axis_.count(buffer_name)) {
307-
indices->clear();
308-
indices->push_back(ir::Expr(0));
309-
return;
310-
}
311230
}
312231

313232
const std::unordered_map<std::string, std::map<int, ir::Expr>>&
314233
buffer_name_access_same_index_expr_;
315-
const std::unordered_map<std::string, ir::Expr>&
316-
buffer_name_to_same_single_axis_;
317234
};
318235

319236
void UpdateBufferAxisPass(ir::Expr* expr) {
320237
VLOG(6) << "Before UpdateBufferAxisPass, Expr = \n" << *expr;
321238

322-
// AnalyzeSingleAxisOfMultDimTensor singler_axis_analyzer;
323-
// singler_axis_analyzer(expr);
324-
// for (auto p : singler_axis_analyzer.buffer_name_to_same_single_axis) {
325-
// VLOG(6) << "Single axis Buffer name: " << p.first;
326-
// VLOG(6) << "Single Expr: " << p.second;
327-
// }
328-
std::unordered_map<std::string, ir::Expr> dump;
329239
AnalyzeBufferAxis buffer_axis_analyzer;
330240
buffer_axis_analyzer(expr);
331241
for (auto p : buffer_axis_analyzer.buffer_name_access_same_index_expr) {
@@ -336,9 +246,7 @@ void UpdateBufferAxisPass(ir::Expr* expr) {
336246
}
337247

338248
ReplaceSameAxisToZero replacer(
339-
buffer_axis_analyzer.buffer_name_access_same_index_expr,
340-
// singler_axis_analyzer.buffer_name_to_same_single_axis);
341-
dump);
249+
buffer_axis_analyzer.buffer_name_access_same_index_expr);
342250
replacer(expr);
343251
VLOG(6) << "After UpdateBufferAxisPass, Expr = \n" << *expr;
344252
}

0 commit comments

Comments
 (0)