Skip to content

Commit e6f7ab4

Browse files
committed
add chunk size & mvin type
1 parent d3b9695 commit e6f7ab4

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

mlir/test/lib/Conversion/MemRefToGemmini/TestMemRefToGemminiConversion.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,12 @@ struct DmaStartOpLowering : public ConvertOpToLLVMPattern<memref::DmaStartOp> {
9595
Value DstPtr = getStridedElementPtr(loc, dstMemRefType, DstMemref, dst_indices, rewriter);
9696
Value main_mem_stride = op.getStride();
9797
int64_t main_mem_stride_val = extractConstantIntValue(main_mem_stride) * elen / 8;
98-
Value num_elt = op.getNumElementsPerStride();
99-
int num_elt_val = extractConstantIntValue(num_elt);
100-
int is_transpose = 0;
101-
if (num_elt_val == 1) // num_elt = 1 means transposed
102-
is_transpose = 1;
98+
Value num_elt_per_stride = op.getNumElementsPerStride();
99+
int chunk_size = extractConstantIntValue(num_elt_per_stride);
100+
bool is_transpose = chunk_size % 2;
101+
chunk_size = chunk_size / 2;
102+
Value numElements = op.getNumElements();
103+
int mvin_type = extractConstantIntValue(numElements);
103104
unsigned SrcAddressSpace =
104105
*getTypeConverter()->getMemRefAddressSpace(srcMemRefType);
105106
unsigned DstAddressSpace =
@@ -121,6 +122,10 @@ struct DmaStartOpLowering : public ConvertOpToLLVMPattern<memref::DmaStartOp> {
121122
} else {
122123
return rewriter.notifyMatchFailure(op, "Unsupported DMA operation");
123124
}
125+
// rs1 = mvin_type << 62 | main memory address
126+
Value shift62 = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(62));
127+
rs1 = rewriter.create<LLVM::ShlOp>(loc, rewriter.getI64Type(), rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(mvin_type)), shift62);
128+
rs1 = rewriter.create<LLVM::OrOp>(loc, rs1, rewriter.create<LLVM::PtrToIntOp>(loc, rewriter.getI64Type(), SrcPtr));
124129
Value rows;
125130
Value cols;
126131
if (tile_shape.size() == 2) {
@@ -151,11 +156,14 @@ struct DmaStartOpLowering : public ConvertOpToLLVMPattern<memref::DmaStartOp> {
151156
OpBuilder::InsertionGuard guard(rewriter);
152157
rewriter.setInsertionPointToStart(&outerBlock);
153158
// config_rs1 = main memory stride
154-
// config_rs2 = is_transpose << 32 | element size
155-
159+
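// config_rs2 = chunk_size << 32 | mvin_type << 17 | is_transpose << 16 | element size [bytes]
// NOTE(review): the mvin_type term below is shifted by config_shift16 (16), not 17, so it overlaps the is_transpose bit -- confirm the intended shift.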
160+
156161
Value config_rs1 = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(main_mem_stride_val));
162+
Value config_shift16 = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(16));
157163
Value config_shift32 = rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(32));
158-
Value config_rs2 = rewriter.create<LLVM::ShlOp>(loc, rewriter.getI64Type(), rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(is_transpose)), config_shift32);
164+
Value config_rs2 = rewriter.create<LLVM::ShlOp>(loc, rewriter.getI64Type(), rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(chunk_size)), config_shift32);
165+
config_rs2 = rewriter.create<LLVM::OrOp>(loc, config_rs2, rewriter.create<LLVM::ShlOp>(loc, rewriter.getI64Type(), rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(mvin_type)), config_shift16));
166+
config_rs2 = rewriter.create<LLVM::OrOp>(loc, config_rs2, rewriter.create<LLVM::ShlOp>(loc, rewriter.getI64Type(), rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(is_transpose)), config_shift16));
159167
config_rs2 = rewriter.create<LLVM::OrOp>(loc, config_rs2, rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), rewriter.getI64IntegerAttr(int(elen/8)))); // element size [bytes]
160168
rewriter.create<LLVM::InlineAsmOp>(
161169
loc,

0 commit comments

Comments
 (0)