Skip to content

Commit c3c57fd

Browse files
committed
Refactor buffer index calculation in TileLangStorageAccessVisitor to simplify access handling. Removed unused buffer mapping logic, ensuring consistent buffer index generation with a default ramp.
1 parent 006a14d commit c3c57fd

File tree

1 file changed

+1
-33
lines changed

1 file changed

+1
-33
lines changed

src/transform/storage_access.cc

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -322,39 +322,7 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
322322
if (Enabled(buffer_var, scope)) {
323323
ICHECK(allow_append_);
324324
Array<PrimExpr> buffer_indices;
325-
if (buffer_data_to_buffer_.find(GetRef<Var>(buffer_var)) ==
326-
buffer_data_to_buffer_.end()) {
327-
// cannot find buffer map, use the default buffer
328-
buffer_indices = {Ramp(offset, 1, extent)};
329-
} else {
330-
Buffer buffer = buffer_data_to_buffer_.at(GetRef<Var>(buffer_var));
331-
auto buffer_shape = buffer->shape;
332-
// convert 1d offset to multi-dimensional index
333-
auto linear_to_indices = [this](PrimExpr offset,
334-
const Array<PrimExpr> &shape) {
335-
Array<PrimExpr> indices;
336-
PrimExpr remaining = offset;
337-
for (size_t i = 0; i < shape.size(); ++i) {
338-
PrimExpr stride = make_const(DataType::Int(32), 1);
339-
for (size_t j = i + 1; j < shape.size(); ++j) {
340-
stride = stride * shape[j];
341-
}
342-
PrimExpr idx = FloorDiv(remaining, stride);
343-
remaining = FloorMod(remaining, stride);
344-
indices.push_back(analyzer_.Simplify(idx));
345-
}
346-
return indices;
347-
};
348-
Array<PrimExpr> start_indices = linear_to_indices(offset, buffer_shape);
349-
Array<PrimExpr> end_indices =
350-
linear_to_indices(offset + extent, buffer_shape);
351-
for (size_t i = 0; i < buffer_shape.size(); ++i) {
352-
buffer_indices.push_back(
353-
Ramp(start_indices[i], 1,
354-
analyzer_.Simplify(end_indices[i] - start_indices[i])));
355-
}
356-
}
357-
325+
buffer_indices = {Ramp(offset, 1, extent)};
358326
AccessEntry e;
359327
e.threads = env_threads();
360328
e.thread_range = this->ComputeThreadRange(e.threads);

0 commit comments

Comments
 (0)