@@ -229,163 +229,6 @@ class ReplaceIndexToBindExpr : public ir::IRMutator<> {
229229 }
230230};
231231
232- using TENSOR_LOOP = std::pair<ir::Expr, std::vector<ir::Expr>>;
233- class CollectTensorLoopVisitor : public ir ::IRMutator<> {
234- public:
235- void operator ()(ir::Expr *expr) { ir::IRMutator<>::Visit (expr, expr); }
236-
237- private:
238- void Visit (const ir::Store *op, Expr *expr) override {
239- auto tensor = op->tensor .as_tensor_ref ();
240- // if buffer defined and buffer is not Heap.
241- if (tensor->buffer .defined () &&
242- tensor->buffer ->memory_type != ir::MemoryType::Heap) {
243- if (buffer_tensor_loop_map_.count (tensor->buffer ->name )) {
244- buffer_tensor_loop_map_[tensor->buffer ->name ].push_back (
245- std::make_pair (*expr, loops_));
246- } else {
247- buffer_tensor_loop_map_[tensor->buffer ->name ] = {
248- std::make_pair (*expr, loops_)};
249- }
250- }
251-
252- IRMutator::Visit (op, expr);
253- }
254-
255- void Visit (const ir::Load *op, Expr *expr) override {
256- if (op->is_addr_scalar ()) {
257- return ;
258- }
259- auto tensor = op->tensor .as_tensor_ref ();
260- // if buffer defined and buffer is not Heap.
261- if (tensor->buffer .defined () &&
262- tensor->buffer ->memory_type != ir::MemoryType::Heap) {
263- if (buffer_tensor_loop_map_.count (tensor->buffer ->name )) {
264- buffer_tensor_loop_map_[tensor->buffer ->name ].push_back (
265- std::make_pair (*expr, loops_));
266- } else {
267- buffer_tensor_loop_map_[tensor->buffer ->name ] = {
268- std::make_pair (*expr, loops_)};
269- }
270- }
271-
272- IRMutator::Visit (op, expr);
273- }
274-
275- void Visit (const ir::For *op, Expr *expr) override {
276- loops_.push_back (*expr);
277- IRMutator::Visit (op, expr);
278- loops_.pop_back ();
279- }
280-
281- void Visit (const ir::PolyFor *op, Expr *expr) override {
282- LOG (FATAL) << " Unkown PolyFor!" ;
283- }
284-
285- public:
286- std::vector<ir::Expr> loops_;
287- std::unordered_map<std::string, std::vector<TENSOR_LOOP>>
288- buffer_tensor_loop_map_;
289- };
290-
291- void UpdateBufferAxisPassOld (ir::Expr *expr) {
292- CollectTensorLoopVisitor collect_tensor_loop_visitor;
293- collect_tensor_loop_visitor (expr);
294-
295- auto buffer_tensor_loop = collect_tensor_loop_visitor.buffer_tensor_loop_map_ ;
296-
297- for (auto &tmp : buffer_tensor_loop) {
298- auto tensor_loop_v = tmp.second ;
299-
300- auto &front = tensor_loop_v.front ();
301- int count = tensor_loop_v.size () > 1 ? front.second .size () : 0 ;
302- for (int idx = 1 ; idx < tensor_loop_v.size (); ++idx) {
303- auto &other = tensor_loop_v[idx];
304- for (int idy = 0 ;
305- idy < std::min (front.second .size (), other.second .size ());
306- ++idy) {
307- if (front.second [idy] != other.second [idy]) {
308- count = std::min (count, idy);
309- break ;
310- }
311- }
312- }
313-
314- auto get_thread_bind_var = [](const std::vector<ir::Expr> &loops) {
315- // threadidx loop_var,extent.
316- using ThreadLoopVarExtentMap =
317- std::unordered_map<std::string, std::pair<std::string, int >>;
318- ThreadLoopVarExtentMap thread_loop_var_exent_map;
319- for (auto loop : loops) {
320- auto loop_ir = loop.As <ir::For>();
321- CHECK (loop_ir);
322- if (loop_ir->is_gpu_thread_binded ()) {
323- std::string axis = " " ;
324- if (loop_ir->bind_info ().offset == 0 ) {
325- axis = " threadIdx.x" ;
326- } else if (loop_ir->bind_info ().offset == 1 ) {
327- axis = " threadIdx.y" ;
328- } else {
329- axis = " threadIdx.z" ;
330- }
331- // insert gpu thread loop var.
332- if (thread_loop_var_exent_map.count (axis)) {
333- auto &loop_var_extent = thread_loop_var_exent_map[axis];
334- if (loop_var_extent.second >= loop_ir->extent .as_int32 ()) {
335- thread_loop_var_exent_map[axis] = std::make_pair (
336- loop_ir->loop_var ->name , loop_ir->extent .as_int32 ());
337- }
338- } else {
339- thread_loop_var_exent_map[axis] = std::make_pair (
340- loop_ir->loop_var ->name , loop_ir->extent .as_int32 ());
341- }
342- }
343- }
344-
345- std::unordered_set<std::string> loop_var_map;
346- for (auto &tmp : thread_loop_var_exent_map) {
347- loop_var_map.insert (tmp.second .first );
348- }
349-
350- return loop_var_map;
351- };
352-
353- auto load = front.first .As <ir::Load>();
354- auto store = front.first .As <ir::Store>();
355- auto tensor =
356- load ? load->tensor .as_tensor_ref () : store->tensor .as_tensor_ref ();
357- // find store and load keep loop for shared
358- std::vector<std::unordered_set<std::string>> keep_loop_vars;
359- if (tensor->buffer ->memory_type == ir::MemoryType::GPUShared) {
360- for (auto &tensor_loop : tensor_loop_v) {
361- keep_loop_vars.push_back (get_thread_bind_var (tensor_loop.second ));
362- }
363- CHECK_EQ (keep_loop_vars.size (), tensor_loop_v.size ());
364- }
365-
366- auto &loops = front.second ;
367- for (int idx = 0 ; idx < count; ++idx) {
368- auto loop_expr = loops[idx];
369- auto loop_ir = loop_expr.As <ir::For>();
370- auto loop_var = loop_ir->loop_var ;
371-
372- for (int idy = 0 ; idy < tensor_loop_v.size (); ++idy) {
373- auto expr = tensor_loop_v[idy].first ;
374- auto load = expr.As <ir::Load>();
375- auto store = expr.As <ir::Store>();
376- if (keep_loop_vars.size () == 0 ||
377- !keep_loop_vars[idy].count (loop_var->name )) {
378- auto &indices = load ? load->indices : store->indices ;
379- for (auto &indice : indices) {
380- optim::ReplaceVarWithExpr (&indice, loop_var, ir::Expr (0 ));
381- indice = cinn::common::AutoSimplify (indice);
382- }
383- }
384- }
385- }
386- }
387- }
388-
389232class ReplaceLoopVarToGpu : public ir ::IRMutator<> {
390233 public:
391234 void operator ()(Expr *expr) { ir::IRMutator<>::Visit (expr, expr); }
@@ -586,7 +429,6 @@ void OptimizeExprGPU(Expr *expr) {
586429
587430 // resize buffer axis
588431 UpdateBufferAxisPass (expr);
589- // UpdateBufferAxisPassOld(expr);
590432
591433 // replace var name with block/thread
592434 ReplaceLoopVarToGpu replace_loop_var_to_gpu;
0 commit comments