@@ -95,11 +95,12 @@ struct DmaStartOpLowering : public ConvertOpToLLVMPattern<memref::DmaStartOp> {
9595 Value DstPtr = getStridedElementPtr (loc, dstMemRefType, DstMemref, dst_indices, rewriter);
9696 Value main_mem_stride = op.getStride ();
9797 int64_t main_mem_stride_val = extractConstantIntValue (main_mem_stride) * elen / 8 ;
98- Value num_elt = op.getNumElementsPerStride ();
99- int num_elt_val = extractConstantIntValue (num_elt);
100- int is_transpose = 0 ;
101- if (num_elt_val == 1 ) // num_elt = 1 means transposed
102- is_transpose = 1 ;
98+ Value num_elt_per_stride = op.getNumElementsPerStride ();
99+ int chunk_size = extractConstantIntValue (num_elt_per_stride);
100+ bool is_transpose = chunk_size % 2 ;
101+ chunk_size = chunk_size / 2 ;
102+ Value numElements = op.getNumElements ();
103+ int mvin_type = extractConstantIntValue (numElements);
103104 unsigned SrcAddressSpace =
104105 *getTypeConverter ()->getMemRefAddressSpace (srcMemRefType);
105106 unsigned DstAddressSpace =
@@ -121,6 +122,10 @@ struct DmaStartOpLowering : public ConvertOpToLLVMPattern<memref::DmaStartOp> {
121122 } else {
122123 return rewriter.notifyMatchFailure (op, " Unsupported DMA operation" );
123124 }
125+ // rs1 = mvin_type << 62 | main memory address
126+ Value shift62 = rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI64Type (), rewriter.getI64IntegerAttr (62 ));
127+ rs1 = rewriter.create <LLVM::ShlOp>(loc, rewriter.getI64Type (), rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI64Type (), rewriter.getI64IntegerAttr (mvin_type)), shift62);
128+ rs1 = rewriter.create <LLVM::OrOp>(loc, rs1, rewriter.create <LLVM::PtrToIntOp>(loc, rewriter.getI64Type (), SrcPtr));
124129 Value rows;
125130 Value cols;
126131 if (tile_shape.size () == 2 ) {
@@ -151,11 +156,14 @@ struct DmaStartOpLowering : public ConvertOpToLLVMPattern<memref::DmaStartOp> {
151156 OpBuilder::InsertionGuard guard (rewriter);
152157 rewriter.setInsertionPointToStart (&outerBlock);
153158 // config_rs1 = main memory stride
154- // config_rs2 = is_transpose << 32 | element size
155-
159+ // config_rs2 = chunk-size << 32 | mvin_type << 17 | is_transpose << 16 | element size
160+
156161 Value config_rs1 = rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI64Type (), rewriter.getI64IntegerAttr (main_mem_stride_val));
162+ Value config_shift16 = rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI64Type (), rewriter.getI64IntegerAttr (16 ));
157163 Value config_shift32 = rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI64Type (), rewriter.getI64IntegerAttr (32 ));
158- Value config_rs2 = rewriter.create <LLVM::ShlOp>(loc, rewriter.getI64Type (), rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI64Type (), rewriter.getI64IntegerAttr (is_transpose)), config_shift32);
164+ Value config_rs2 = rewriter.create <LLVM::ShlOp>(loc, rewriter.getI64Type (), rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI64Type (), rewriter.getI64IntegerAttr (chunk_size)), config_shift32);
165+ config_rs2 = rewriter.create <LLVM::OrOp>(loc, config_rs2, rewriter.create <LLVM::ShlOp>(loc, rewriter.getI64Type (), rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI64Type (), rewriter.getI64IntegerAttr (mvin_type)), config_shift16));
166+ config_rs2 = rewriter.create <LLVM::OrOp>(loc, config_rs2, rewriter.create <LLVM::ShlOp>(loc, rewriter.getI64Type (), rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI64Type (), rewriter.getI64IntegerAttr (is_transpose)), config_shift16));
159167 config_rs2 = rewriter.create <LLVM::OrOp>(loc, config_rs2, rewriter.create <LLVM::ConstantOp>(loc, rewriter.getI64Type (), rewriter.getI64IntegerAttr (int (elen/8 )))); // element size [bytes]
160168 rewriter.create <LLVM::InlineAsmOp>(
161169 loc,
0 commit comments