@@ -38,6 +38,7 @@ using namespace tir;
3838
3939void  TileLangStorageAccessVisitor::VisitExpr_ (const  BufferLoadNode *op) {
4040  Var buf = op->buffer ->data ;
41+   buffer_data_to_buffer_.Set (GetRef<Var>(buf.get ()), op->buffer );
4142  StorageScope scope = GetScope (buf);
4243  if  (Enabled (buf.get (), scope)) {
4344    ICHECK (allow_append_) << GetRef<BufferLoad>(op) << "  " to_string ();
@@ -64,6 +65,7 @@ void TileLangStorageAccessVisitor::VisitStmt_(const BufferStoreNode *op) {
6465  curr_stmt_.stmt  = op;
6566
6667  Var buf = op->buffer ->data ;
68+   buffer_data_to_buffer_.Set (GetRef<Var>(buf.get ()), op->buffer );
6769  StorageScope scope = GetScope (buf);
6870  if  (Enabled (buf.get (), scope)) {
6971    AccessEntry e;
@@ -115,6 +117,15 @@ void TileLangStorageAccessVisitor::VisitStmt_(const LetStmtNode *op) {
115117  this ->VisitStmt (op->body );
116118}
117119
120+ void  TileLangStorageAccessVisitor::VisitStmt_ (const  BlockNode *op) {
121+   auto  block = Downcast<Block>(op);
122+   for  (const  auto  &buffer : block->alloc_buffers ) {
123+     ICHECK (buffer->IsInstance <BufferNode>());
124+     buffer_data_to_buffer_.Set (buffer->data , buffer);
125+   }
126+   IRVisitorWithAnalyzer::VisitStmt_ (op);
127+ }
128+ 
118129void  TileLangStorageAccessVisitor::VisitStmt_ (const  AttrStmtNode *op) {
119130  if  (op->attr_key  == tvm::tir::attr::double_buffer_write) {
120131    ICHECK (double_buffer_write_ == nullptr );
@@ -271,18 +282,27 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
271282      Buffer buffer = load->buffer ;
272283      DataType dtype = buffer->dtype ;
273284      const  VarNode *buffer_var = buffer->data .as <VarNode>();
285+       buffer_data_to_buffer_.Set (GetRef<Var>(buffer_var), buffer);
274286      StorageScope scope = GetScope (GetRef<Var>(buffer_var));
287+       Array<Range> buffer_ranges;
288+       //  from indices to buffer indices
289+       ICHECK (buffer->shape .size () == load->indices .size ());
290+       for  (size_t  i = 0 ; i < buffer->shape .size (); ++i) {
291+         buffer_ranges.push_back (
292+             Range::FromMinExtent (load->indices [i], buffer->shape [i]));
293+       }
275294      if  (Enabled (buffer_var, scope)) {
276295        ICHECK (allow_append_);
277296        AccessEntry e;
278297        e.threads  = env_threads ();
279298        e.thread_range  = this ->ComputeThreadRange (e.threads );
280299        e.dtype  = dtype;
281300        e.buffer  = Downcast<Var>(buffer->data );
282-         e.buffer_indices  = load-> indices ;
301+         e.buffer_ranges  = buffer_ranges ;
283302        for  (const  auto  &index : load->indices ) {
284303          e.touched .push_back (arith::IntSet::Vector (index));
285304        }
305+         e.is_pointer_access  = true ;
286306        e.type  = kRead ;
287307        e.scope  = scope;
288308        curr_stmt_.access .emplace_back (e);
@@ -294,20 +314,54 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
294314  } else  if  (op->op .same_as (builtin::tvm_access_ptr ())) {
295315    ICHECK_EQ (op->args .size (), 5U );
296316    DataType dtype = op->args [0 ].dtype ();
297-     const  VarNode *buffer  = op->args [1 ].as <VarNode>();
317+     const  VarNode *buffer_var  = op->args [1 ].as <VarNode>();
298318    PrimExpr offset = op->args [2 ];
299319    PrimExpr extent = op->args [3 ];
300320    const  IntImmNode *flag = op->args [4 ].as <IntImmNode>();
301-     StorageScope scope = GetScope (GetRef<Var>(buffer ));
321+     StorageScope scope = GetScope (GetRef<Var>(buffer_var ));
302322    //  The buffer scope.
303-     if  (Enabled (buffer , scope)) {
323+     if  (Enabled (buffer_var , scope)) {
304324      ICHECK (allow_append_);
325+       Array<Range> buffer_ranges;
326+       if  (buffer_data_to_buffer_.find (GetRef<Var>(buffer_var)) ==
327+           buffer_data_to_buffer_.end ()) {
328+         //  cannot find buffer map, use the default buffer
329+         buffer_ranges = {Range::FromMinExtent (offset, extent)};
330+       } else  {
331+         Buffer buffer = buffer_data_to_buffer_.at (GetRef<Var>(buffer_var));
332+         auto  buffer_shape = buffer->shape ;
333+         //  convert 1d offset to multi-dimensional index
334+         auto  linear_to_indices = [this ](PrimExpr offset,
335+                                         const  Array<PrimExpr> &shape) {
336+           Array<PrimExpr> indices;
337+           PrimExpr remaining = offset;
338+           for  (size_t  i = 0 ; i < shape.size (); ++i) {
339+             PrimExpr stride = make_const (DataType::Int (32 ), 1 );
340+             for  (size_t  j = i + 1 ; j < shape.size (); ++j) {
341+               stride = stride * shape[j];
342+             }
343+             PrimExpr idx = FloorDiv (remaining, stride);
344+             remaining = FloorMod (remaining, stride);
345+             indices.push_back (analyzer_.Simplify (idx));
346+           }
347+           return  indices;
348+         };
349+         Array<PrimExpr> start_indices = linear_to_indices (offset, buffer_shape);
350+         Array<PrimExpr> end_indices =
351+             linear_to_indices (offset + extent, buffer_shape);
352+         for  (size_t  i = 0 ; i < buffer_shape.size (); ++i) {
353+           buffer_ranges.push_back (Range::FromMinExtent (
354+               start_indices[i],
355+               analyzer_.Simplify (end_indices[i] - start_indices[i])));
356+         }
357+       }
305358      AccessEntry e;
306359      e.threads  = env_threads ();
307360      e.thread_range  = this ->ComputeThreadRange (e.threads );
308361      e.dtype  = dtype;
309-       e.buffer  = Downcast<Var>(op->args [1 ]);
310-       e.buffer_indices  = {offset, extent};
362+       e.buffer  = GetRef<Var>(buffer_var);
363+       e.buffer_ranges  = buffer_ranges;
364+       e.is_pointer_access  = true ;
311365      e.touched  = {
312366          arith::IntSet::FromRange (Range::FromMinExtent (offset, extent))};
313367      e.scope  = scope;
0 commit comments