158 changes: 0 additions & 158 deletions paddle/cinn/optim/transform_gpu_forloop.cc
@@ -229,163 +229,6 @@ class ReplaceIndexToBindExpr : public ir::IRMutator<> {
}
};

using TENSOR_LOOP = std::pair<ir::Expr, std::vector<ir::Expr>>;
class CollectTensorLoopVisitor : public ir::IRMutator<> {
public:
void operator()(ir::Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }

private:
void Visit(const ir::Store *op, Expr *expr) override {
auto tensor = op->tensor.as_tensor_ref();
// if buffer defined and buffer is not Heap.
if (tensor->buffer.defined() &&
tensor->buffer->memory_type != ir::MemoryType::Heap) {
if (buffer_tensor_loop_map_.count(tensor->buffer->name)) {
buffer_tensor_loop_map_[tensor->buffer->name].push_back(
std::make_pair(*expr, loops_));
} else {
buffer_tensor_loop_map_[tensor->buffer->name] = {
std::make_pair(*expr, loops_)};
}
}

IRMutator::Visit(op, expr);
}

void Visit(const ir::Load *op, Expr *expr) override {
if (op->is_addr_scalar()) {
return;
}
auto tensor = op->tensor.as_tensor_ref();
// if buffer defined and buffer is not Heap.
if (tensor->buffer.defined() &&
tensor->buffer->memory_type != ir::MemoryType::Heap) {
if (buffer_tensor_loop_map_.count(tensor->buffer->name)) {
buffer_tensor_loop_map_[tensor->buffer->name].push_back(
std::make_pair(*expr, loops_));
} else {
buffer_tensor_loop_map_[tensor->buffer->name] = {
std::make_pair(*expr, loops_)};
}
}

IRMutator::Visit(op, expr);
}

void Visit(const ir::For *op, Expr *expr) override {
loops_.push_back(*expr);
IRMutator::Visit(op, expr);
loops_.pop_back();
}

void Visit(const ir::PolyFor *op, Expr *expr) override {
LOG(FATAL) << "Unknown PolyFor!";
}

public:
std::vector<ir::Expr> loops_;
std::unordered_map<std::string, std::vector<TENSOR_LOOP>>
buffer_tensor_loop_map_;
};

void UpdateBufferAxisPassOld(ir::Expr *expr) {
CollectTensorLoopVisitor collect_tensor_loop_visitor;
collect_tensor_loop_visitor(expr);

auto buffer_tensor_loop = collect_tensor_loop_visitor.buffer_tensor_loop_map_;

for (auto &tmp : buffer_tensor_loop) {
auto tensor_loop_v = tmp.second;

auto &front = tensor_loop_v.front();
int count = tensor_loop_v.size() > 1 ? front.second.size() : 0;
for (int idx = 1; idx < tensor_loop_v.size(); ++idx) {
auto &other = tensor_loop_v[idx];
for (int idy = 0;
idy < std::min(front.second.size(), other.second.size());
++idy) {
if (front.second[idy] != other.second[idy]) {
count = std::min(count, idy);
break;
}
}
}

auto get_thread_bind_var = [](const std::vector<ir::Expr> &loops) {
// Map from threadIdx axis to (loop_var, extent).
using ThreadLoopVarExtentMap =
std::unordered_map<std::string, std::pair<std::string, int>>;
ThreadLoopVarExtentMap thread_loop_var_exent_map;
for (auto loop : loops) {
auto loop_ir = loop.As<ir::For>();
CHECK(loop_ir);
if (loop_ir->is_gpu_thread_binded()) {
std::string axis = "";
if (loop_ir->bind_info().offset == 0) {
axis = "threadIdx.x";
} else if (loop_ir->bind_info().offset == 1) {
axis = "threadIdx.y";
} else {
axis = "threadIdx.z";
}
// insert gpu thread loop var.
if (thread_loop_var_exent_map.count(axis)) {
auto &loop_var_extent = thread_loop_var_exent_map[axis];
if (loop_var_extent.second >= loop_ir->extent.as_int32()) {
thread_loop_var_exent_map[axis] = std::make_pair(
loop_ir->loop_var->name, loop_ir->extent.as_int32());
}
} else {
thread_loop_var_exent_map[axis] = std::make_pair(
loop_ir->loop_var->name, loop_ir->extent.as_int32());
}
}
}

std::unordered_set<std::string> loop_var_map;
for (auto &tmp : thread_loop_var_exent_map) {
loop_var_map.insert(tmp.second.first);
}

return loop_var_map;
};

auto load = front.first.As<ir::Load>();
auto store = front.first.As<ir::Store>();
auto tensor =
load ? load->tensor.as_tensor_ref() : store->tensor.as_tensor_ref();
// For GPUShared buffers, record the thread-bound loop vars each store/load must keep.
std::vector<std::unordered_set<std::string>> keep_loop_vars;
if (tensor->buffer->memory_type == ir::MemoryType::GPUShared) {
for (auto &tensor_loop : tensor_loop_v) {
keep_loop_vars.push_back(get_thread_bind_var(tensor_loop.second));
}
CHECK_EQ(keep_loop_vars.size(), tensor_loop_v.size());
}

auto &loops = front.second;
for (int idx = 0; idx < count; ++idx) {
auto loop_expr = loops[idx];
auto loop_ir = loop_expr.As<ir::For>();
auto loop_var = loop_ir->loop_var;

for (int idy = 0; idy < tensor_loop_v.size(); ++idy) {
auto expr = tensor_loop_v[idy].first;
auto load = expr.As<ir::Load>();
auto store = expr.As<ir::Store>();
if (keep_loop_vars.size() == 0 ||
!keep_loop_vars[idy].count(loop_var->name)) {
auto &indices = load ? load->indices : store->indices;
for (auto &indice : indices) {
optim::ReplaceVarWithExpr(&indice, loop_var, ir::Expr(0));
indice = cinn::common::AutoSimplify(indice);
}
}
}
}
}
}

class ReplaceLoopVarToGpu : public ir::IRMutator<> {
public:
void operator()(Expr *expr) { ir::IRMutator<>::Visit(expr, expr); }
@@ -586,7 +429,6 @@ void OptimizeExprGPU(Expr *expr) {

// resize buffer axis
UpdateBufferAxisPass(expr);
// UpdateBufferAxisPassOld(expr);

// replace var name with block/thread
ReplaceLoopVarToGpu replace_loop_var_to_gpu;
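The deleted `UpdateBufferAxisPassOld` above and the `UpdateBufferAxisPass` that replaces it rest on the same observation: when every access to a non-heap buffer sits under the same prefix of enclosing loops, indices driven by those common loop variables can be replaced with 0 without changing any value a reader observes, which is what lets later passes shrink the buffer along that axis. A minimal standalone sketch of that legality argument, in plain C++ rather than the CINN IR (names and shapes here are illustrative, not taken from the pass):

```cpp
#include <cassert>
#include <vector>

int main() {
  const int N = 4, M = 8;
  // "temp" is a scratch buffer written and read only inside one iteration
  // of the outer i-loop, so axis 0 never distinguishes two live values.
  std::vector<float> full(N * M);   // temp[i][j], axis 0 kept
  std::vector<float> collapsed(M);  // temp[0][j], axis 0 replaced with 0
  for (int i = 0; i < N; ++i) {
    for (int j = 0; j < M; ++j) {
      full[i * M + j] = i + 0.5f * j;  // store temp[i][j]
      collapsed[j] = i + 0.5f * j;     // store temp[0][j]
    }
    for (int j = 0; j < M; ++j) {
      // Every read in this iteration sees the same value either way.
      assert(full[i * M + j] == collapsed[j]);
    }
  }
  return 0;
}
```

The pass automates this rewrite at the IR level: `count` in the code above is the length of the common loop prefix, and `ReplaceVarWithExpr(&indice, loop_var, ir::Expr(0))` performs the substitution, skipping thread-bound loop vars that a GPUShared buffer still needs.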
100 changes: 4 additions & 96 deletions paddle/cinn/optim/update_buffer_axis_pass.cc
@@ -61,80 +61,6 @@ void FormalizeSingleIndex(const ir::Tensor& tensor,
}
}

/**
 * This pass updates the buffer access when a single axis is used to
 * address a multi-dim tensor. For example, if tensor t has
 * t.shape = [2, 3, 4], the buffer access t[12 * k] is the same as
 * t[k, 0, 0]. The equivalence is easy for a human to see, but not
 * easy for the compiler.
 *
 * This class checks that the buffer accesses are the same and updates
 * those same buffer accesses with the same index expr.
 *
 * Note! This is a temporary solution. Our symbolic simplification is
 * not powerful enough to simplify expressions such as 12 * k / 4 % 3,
 * so we only handle the simplest case. We can extend this class once
 * we can simplify 12 * k / 4 % 3 well. (A standalone sketch of the
 * single-axis equivalence follows this class.)
 */
class AnalyzeSingleAxisOfMultDimTensor : public ir::IRMutator<> {
public:
void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

void Visit(const ir::Store* op, Expr* expr) override {
ir::Store* store = expr->As<ir::Store>();
ir::Tensor tensor = store->tensor.as_tensor_ref();
AnalyzeSingleAxisAccess(store->indices, tensor);
ir::IRMutator<>::Visit(op, expr);
}

// Analyze the buffer access inside load
void Visit(const ir::Load* op, Expr* expr) override {
ir::Load* load = expr->As<ir::Load>();
ir::Tensor tensor = load->tensor.as_tensor_ref();
AnalyzeSingleAxisAccess(load->indices, tensor);
ir::IRMutator<>::Visit(op, expr);
}

void AnalyzeSingleAxisAccess(const std::vector<Expr>& indices,
const ir::Tensor& tensor) {
if (!tensor->buffer.defined() ||
tensor->buffer->memory_type == ir::MemoryType::Heap ||
tensor->buffer->memory_type == ir::MemoryType::GPUShared) {
return;
}
CHECK(indices.size() > 0) << "Buffer access indices is empty";
const std::string& buffer_name = tensor->buffer->name;
const std::vector<ir::Expr>& shape = tensor->shape;

ir::Expr index_expr;
if (indices.size() == 1 && shape.size() > 1) {
index_expr = indices[0];
} else if (indices.size() == shape.size()) {
ir::Expr mul = Expr(1);
index_expr = indices.back();
for (int i = static_cast<int>(indices.size()) - 2; i >= 0; --i) {
mul = ir::Mul::Make(shape[i + 1], mul);
ir::Expr cur = ir::Mul::Make(indices[i], mul);
index_expr = ir::Add::Make(cur, index_expr);
}
}
index_expr = common::AutoSimplify(index_expr);

if (!buffer_name_to_same_single_axis.count(buffer_name)) {
buffer_name_to_same_single_axis[buffer_name] = index_expr;
return;
} else {
const ir::Expr& stored_index_expr =
buffer_name_to_same_single_axis[buffer_name];
if (!ExprMathEqual(index_expr, stored_index_expr)) {
buffer_name_to_same_single_axis.erase(buffer_name);
}
}
}

std::unordered_map<std::string, ir::Expr> buffer_name_to_same_single_axis;
};
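As referenced in the comment above, here is a standalone sketch of the single-axis equivalence (plain C++, not the CINN API; the fold mirrors the stride accumulation in `AnalyzeSingleAxisAccess`):

```cpp
#include <cassert>
#include <vector>

// Row-major linear offset of `indices` into a tensor of shape `shape`:
// offset = ((i0 * s1) + i1) * s2 + i2, i.e. strides [s1*s2, s2, 1].
int Linearize(const std::vector<int>& indices, const std::vector<int>& shape) {
  int offset = 0;
  for (size_t i = 0; i < indices.size(); ++i) {
    offset = offset * shape[i] + indices[i];
  }
  return offset;
}

int main() {
  const std::vector<int> shape = {2, 3, 4};  // the shape from the comment
  for (int k = 0; k < 2; ++k) {
    // t[k, 0, 0] and t[12 * k] address the same element: the stride of
    // axis 0 is 3 * 4 = 12, matching the comment's example.
    assert(Linearize({k, 0, 0}, shape) == 12 * k);
  }
  return 0;
}
```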

class AnalyzeBufferAxis : public ir::IRMutator<> {
public:
void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
@@ -260,11 +186,9 @@ class ReplaceSameAxisToZero : public ir::IRMutator<> {
public:
ReplaceSameAxisToZero(
const std::unordered_map<std::string, std::map<int, ir::Expr>>&
buffer_name_access_same_index_expr,
const std::unordered_map<std::string, ir::Expr>&
buffer_name_to_same_single_axis)
: buffer_name_access_same_index_expr_(buffer_name_access_same_index_expr),
buffer_name_to_same_single_axis_(buffer_name_to_same_single_axis) {}
buffer_name_access_same_index_expr)
: buffer_name_access_same_index_expr_(
buffer_name_access_same_index_expr) {}

void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

@@ -303,29 +227,15 @@ class ReplaceSameAxisToZero : public ir::IRMutator<> {
}
return;
}
if (buffer_name_to_same_single_axis_.count(buffer_name)) {
indices->clear();
indices->push_back(ir::Expr(0));
return;
}
}

const std::unordered_map<std::string, std::map<int, ir::Expr>>&
buffer_name_access_same_index_expr_;
const std::unordered_map<std::string, ir::Expr>&
buffer_name_to_same_single_axis_;
};

void UpdateBufferAxisPass(ir::Expr* expr) {
VLOG(6) << "Before UpdateBufferAxisPass, Expr = \n" << *expr;

// AnalyzeSingleAxisOfMultDimTensor singler_axis_analyzer;
// singler_axis_analyzer(expr);
// for (auto p : singler_axis_analyzer.buffer_name_to_same_single_axis) {
// VLOG(6) << "Single axis Buffer name: " << p.first;
// VLOG(6) << "Single Expr: " << p.second;
// }
std::unordered_map<std::string, ir::Expr> dump;
AnalyzeBufferAxis buffer_axis_analyzer;
buffer_axis_analyzer(expr);
for (auto p : buffer_axis_analyzer.buffer_name_access_same_index_expr) {
@@ -336,9 +246,7 @@ void UpdateBufferAxisPass(ir::Expr* expr) {
}

ReplaceSameAxisToZero replacer(
buffer_axis_analyzer.buffer_name_access_same_index_expr,
// singler_axis_analyzer.buffer_name_to_same_single_axis);
dump);
buffer_axis_analyzer.buffer_name_access_same_index_expr);
replacer(expr);
VLOG(6) << "After UpdateBufferAxisPass, Expr = \n" << *expr;
}