
[MLIR] Implement emulation of static indexing subbyte type vector stores #115922

Open
lialan wants to merge 2 commits into main from lialan/atomic_stores
Conversation

@lialan lialan (Contributor) commented Nov 12, 2024

This patch enables unaligned, statically indexed stores of vectors whose element types are narrower than the emulation width. By default, the patch ensures atomicity by emitting an atomic read-modify-write sequence wherever data contention might occur. The caller can avoid emitting atomic operations by setting `useAtomicWrites` to false when calling `populateVectorNarrowTypeEmulationPatterns`; in that case, a regular RMW sequence is emitted instead, for efficiency.
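To sketch the intended caller-side usage (illustrative only, not code from this patch; it assumes the new `useAtomicWrites` flag lands as a trailing parameter of `populateVectorNarrowTypeEmulationPatterns`, and the exact signature may differ):

```cpp
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical caller: it can guarantee (e.g. from its own dependence
// analysis) that no two threads ever write into the same emulated byte,
// so it opts out of the atomic RMW lowering.
static void collectEmulationPatterns(RewritePatternSet &patterns) {
  // Emulate sub-byte element types on top of 8-bit memory accesses.
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  // The trailing flag is the new option described above; its exact
  // placement/name in the final patch is an assumption of this sketch.
  vector::populateVectorNarrowTypeEmulationPatterns(
      typeConverter, patterns, /*useAtomicWrites=*/false);
}
```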

To illustrate the mechanism, consider storing vector<7xi2> into memref<3x7xi2>[1, 0]. In this case the linearized bit indices being overwritten are [14, 28), which cover:

  • the last 2 bits of byte no.2
  • all of byte no.3
  • the first 4 bits of byte no.4

This patch expands the store into two (possibly atomic) partial RMW stores (for the first and the third byte) and one non-atomic full-byte store (for the second byte).
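The byte-splitting arithmetic for this example can be checked with a minimal standalone sketch (illustrative only, not code from the patch; names are made up):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t elemBits = 2;             // i2 elements
  const int64_t scale = 8 / elemBits;     // 4 elements per emulated byte
  const int64_t numElements = 7;          // vector<7xi2>
  const int64_t linearIndex = 1 * 7 + 0;  // memref<3x7xi2>[1, 0], linearized

  const int64_t firstByte = linearIndex / scale;    // 0-based index of byte no.2
  const int64_t intraOffset = linearIndex % scale;  // elements already packed in that byte

  // 1. Partial front byte: masked (possibly atomic) read-modify-write.
  const int64_t frontElems =
      intraOffset == 0 ? 0 : std::min(scale - intraOffset, numElements);
  // 2. Full middle bytes: a plain (non-atomic) vector.store of the bitcast bytes.
  const int64_t fullBytes = (numElements - frontElems) / scale;
  // 3. Partial back byte: a second masked (possibly atomic) read-modify-write.
  const int64_t backElems = numElements - frontElems - fullBytes * scale;

  std::cout << "front byte index " << firstByte << ": " << frontElems
            << " element(s), " << fullBytes << " full byte(s), back byte: "
            << backElems << " element(s)\n";
  // Prints: front byte index 1: 1 element(s), 1 full byte(s), back byte: 2 element(s)
  // i.e. bits [14, 28): the last 2 bits of byte no.2, all of byte no.3, and
  // the first 4 bits of byte no.4.
  return 0;
}
```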

@llvmbot llvmbot commented Nov 12, 2024

@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: lialan (lialan)

Changes

This patch enables unaligned, statically indexed stores of vectors with sub-byte element types.

To ensure atomicity, the boundary bytes at the front and at the back are handled separately with atomic stores to avoid thread contention, while the bytes (if any) that are free from race conditions use non-atomic stores.

To illustrate the mechanism, consider storing vector<7xi2> into memref<3x7xi2>[1, 0]. In this case the linearized bit indices being overwritten are [14, 28), which cover:

  • the last 2 bits of byte no.2
  • all of byte no.3
  • the first 4 bits of byte no.4

This patch expands the store into two atomic partial stores (for the first and the third byte) and one non-atomic full-byte store (for the second byte).


Patch is 21.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115922.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+166-33)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir (+135)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 7578aadee23a6e..031eb2d9c443b8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -143,19 +143,19 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
 /// emitting `vector.extract_strided_slice`.
 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                        VectorType extractType, Value source,
-                                        int64_t frontOffset,
+                                        Value source, int64_t frontOffset,
                                         int64_t subvecSize) {
-  auto vectorType = cast<VectorType>(source.getType());
-  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-         "expected 1-D source and destination types");
-  (void)vectorType;
+  auto vectorType = llvm::cast<VectorType>(source.getType());
+  assert(vectorType.getRank() == 1 && "expected 1-D source types");
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
+
+  auto resultVectorType =
+      VectorType::get({subvecSize}, vectorType.getElementType());
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
-                                             sizes, strides)
+      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+                                             offsets, sizes, strides)
       ->getResult(0);
 }
 
@@ -164,12 +164,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
 /// `vector.insert_strided_slice`.
 static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                        Value src, Value dest, int64_t offset) {
-  auto srcType = cast<VectorType>(src.getType());
-  auto destType = cast<VectorType>(dest.getType());
+  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
+  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
   assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
          "expected source and dest to be vector type");
-  (void)srcType;
-  (void)destType;
   auto offsets = rewriter.getI64ArrayAttr({offset});
   auto strides = rewriter.getI64ArrayAttr({1});
   return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
@@ -236,6 +234,63 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
       newLoad);
 }
 
+static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
+                           Value memref, Value index, Value value) {
+  auto originType = dyn_cast<VectorType>(value.getType());
+  auto memrefElemType = dyn_cast<MemRefType>(memref.getType()).getElementType();
+  auto scale = memrefElemType.getIntOrFloatBitWidth() /
+               originType.getElementType().getIntOrFloatBitWidth();
+  auto storeType =
+      VectorType::get({originType.getNumElements() / scale}, memrefElemType);
+  auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
+  rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
+}
+
+/// atomically store a subbyte-sized value to memory, with a mask.
+static Value atomicStore(OpBuilder &rewriter, Location loc,
+                         Value emulatedMemref, Value emulatedIndex,
+                         TypedValue<VectorType> value, Value mask,
+                         int64_t scale) {
+  auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
+      loc, emulatedMemref, ValueRange{emulatedIndex});
+  OpBuilder builder =
+      OpBuilder::atBlockEnd(atomicOp.getBody(), rewriter.getListener());
+  Value origValue = atomicOp.getCurrentValue();
+
+  // i8 -> vector type <1xi8> then <1xi8> -> <scale x i.>
+  auto oneVectorType = VectorType::get({1}, origValue.getType());
+  auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+                                                         ValueRange{origValue});
+  auto vectorBitCast =
+      builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+  auto select =
+      builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+  auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+  auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+  builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+  return atomicOp;
+}
+
+// Extract a slice of a vector, and insert it into a byte vector.
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, TypedValue<VectorType> vector,
+                                  int64_t sliceOffset, int64_t sliceNumElements,
+                                  int64_t byteOffset) {
+  auto vectorElementType = vector.getType().getElementType();
+  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+         "vector element must be a valid sub-byte type");
+  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, VectorType::get({scale}, vectorElementType),
+      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                              sliceOffset, sliceNumElements);
+  auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+                                            emptyByteVector, byteOffset);
+  return inserted;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -251,7 +306,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
+    auto valueToStore = op.getValueToStore();
+    Type oldElementType = valueToStore.getType().getElementType();
     Type newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -275,15 +331,15 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
     // vector<4xi8>
 
-    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    auto origElements = valueToStore.getType().getNumElements();
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
@@ -291,14 +347,94 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    auto foldedIntraVectorOffset =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedIntraVectorOffset) {
+      // unimplemented case for dynamic front padding size
+      return failure();
+    }
+
+    Value emulatedMemref = adaptor.getBase();
+    // the index into the target memref we are storing to
+    Value currentDestIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
+    // the index into the source vector we are currently processing
+    auto currentSourceIndex = 0;
+
+    // 1. atomic store for the first byte
+    auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
+    if (frontAtomicStoreElem != 0) {
+      auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
+      if (*foldedIntraVectorOffset + origElements < scale) {
+        std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+                    origElements, true);
+        frontAtomicStoreElem = origElements;
+      } else {
+        std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
+                    *foldedIntraVectorOffset, true);
+      }
+      auto frontMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
+
+      currentSourceIndex = scale - (*foldedIntraVectorOffset);
+      auto value = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
+          frontAtomicStoreElem, *foldedIntraVectorOffset);
+
+      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                  cast<TypedValue<VectorType>>(value), frontMask.getResult(),
+                  scale);
+
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+    }
+
+    if (currentSourceIndex >= origElements) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // 2. non-atomic store
+    int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
+    int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
+    if (nonAtomicStoreSize != 0) {
+      auto nonAtomicStorePart = staticallyExtractSubvector(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, numNonAtomicElements);
+
+      nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                     nonAtomicStorePart);
+
+      currentSourceIndex += numNonAtomicElements;
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex,
+          rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
+    }
+
+    // 3. atomic store for the last byte
+    auto remainingElements = origElements - currentSourceIndex;
+    if (remainingElements != 0) {
+      auto atomicStorePart = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, remainingElements, 0);
+
+      // back mask
+      auto maskValues = llvm::SmallVector<bool>(scale, 0);
+      std::fill_n(maskValues.begin(), remainingElements, 1);
+      auto backMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(atomicMaskType, maskValues));
+
+      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                  cast<TypedValue<VectorType>>(atomicStorePart),
+                  backMask.getResult(), scale);
+    }
 
-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    rewriter.eraseOp(op);
     return success();
   }
 };
@@ -496,9 +632,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -652,9 +787,8 @@ struct ConvertVectorMaskedLoad final
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
           op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -732,9 +866,8 @@ struct ConvertVectorTransferRead final
                                            linearizedInfo.intraDataOffset,
                                            origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
index 7ed75ff7f1579c..3b1f5b2e160fe0 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned.mlir
@@ -249,3 +249,138 @@ func.func @vector_maskedload_i2_dynamic_indexing_mixed(%passthru: vector<3xi2>,
 // CHECK: %[[IN8:.+]] = vector.insert %[[EX8]], %[[IN7]] [1] : i2 into vector<3xi2>
 // CHECK: %[[EX9:.+]] = vector.extract %[[SELECT]][%[[INCIDX2]]] : i2 from vector<8xi2>
 // CHECK: %[[IN9:.+]] = vector.insert %[[EX9]], %[[IN8]] [2] : i2 into vector<3xi2>
+
+// -----
+
+func.func @vector_store_i2_const(%arg0: vector<3xi2>) {
+    %0 = memref.alloc() : memref<3x3xi2>
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    vector.store %arg0, %0[%c2, %c0] :memref<3x3xi2>, vector<3xi2>
+    return
+}
+
+// in this example, emit 2 atomic stores, with the first storing 1 element and the second storing 2 elements.
+// CHECK: func @vector_store_i2_const(
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+
+// atomic store of the first byte
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<3xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// atomic store of the second byte
+// CHECK: %[[ADDI:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDI]]] : memref<3xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST3]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST4]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT3]] : i8
+
+// -----
+
+func.func @vector_store_i8_2(%arg0: vector<7xi2>) {
+    %0 = memref.alloc() : memref<3x7xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    vector.store %arg0, %0[%c1, %c0] :memref<3x7xi2>, vector<7xi2>
+    return
+}
+
+// in this example, emit 2 atomic stores and 1 non-atomic store
+
+// CHECK: func @vector_store_i8_2(
+// CHECK-SAME: %[[ARG0:.+]]: vector<7xi2>)
+// CHECK: %[[ALLOC]] = memref.alloc() : memref<6xi8>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]> : vector<4xi1>
+// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
+
+// first atomic store
+// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]} : vector<7xi2> to vector<1xi2>
+// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
+// CHECK-SAME: {offsets = [3], strides = [1]} : vector<1xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[C1]]] : memref<6xi8> {
+// CHECK: %[[ARG:.+]]: i8):
+// CHECK: %[[FROM_ELEM:.+]] = vector.from_elements %[[ARG]] : vector<1xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[FROM_ELEM]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[BITCAST]] : vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST2:.+]] = vector.bitcast %[[SELECT]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST2]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT2]] : i8
+
+// non atomic store part
+// CHECK: %[[ADDR:.+]] = arith.addi %[[C1]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]} : vector<7xi2> to vector<4xi2>
+// CHECK: %[[BITCAST3:.+]] = vector.bitcast %[[EXTRACT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: vector.store %[[BITCAST3]], %[[ALLOC]][%[[ADDR]]] : memref<6xi8>, vector<1xi8>
+
+// second atomic store
+// CHECK: %[[ADDR2:.+]] = arith.addi %[[ADDR]], %[[C1]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[ARG0]]
+// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]} : vector<7xi2> to vector<2xi2>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT3]], %[[CST0]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xi2> into vector<4xi2>
+// CHECK: %[[ATOMIC_RMW2:.+]] = memref.generic_atomic_rmw %[[ALLOC]][%[[ADDR2]]] : memref<6xi8> {
+// CHECK: %[[ARG2:.+]]: i8):
+// CHECK: %[[FROM_ELEM2:.+]] = vector.from_elements %[[ARG2]] : vector<1xi8>
+// CHECK: %[[BITCAST4:.+]] = vector.bitcast %[[FROM_ELEM2]] : vector<1xi8> to vector<4xi2>
+// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[BITCAST4]] :
+// CHECK-SAME: vector<4xi1>, vector<4xi2>
+// CHECK: %[[BITCAST5:.+]] = vector.bitcast %[[SELECT2]] : vector<4xi2> to vector<1xi8>
+// CHECK: %[[EXTRACT4:.+]] = vector.extract %[[BITCAST5]][0] : i8 from vector<1xi8>
+// CHECK: memref.atomic_yield %[[EXTRACT4]] : i8    
+
+// -----
+
+func.func @vector_store_i2_single_atomic(%arg0: vector<1xi2>) {
+    %0 = memref.alloc() : memref<4x1xi2>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    vector.store %arg0, %0[%c1, %c0] :memref<4x1xi2>, vector<1xi2>
+    return
+}
+
+// in this example, only emit 1 ato...
[truncated]


@MaheshRavishankar MaheshRavishankar (Contributor) left a comment


What would happen if the start of the store is aligned? Would it still generate atomic stores?

I haven't looked into the details yet, but we only need an atomic store if we cannot guarantee that there are no competing stores. We cannot have this information at this level, but a caller might. I think it would be better to allow a caller to indicate that atomic stores are not needed, with atomic stores as the default.

@lialan lialan force-pushed the lialan/atomic_stores branch 2 times, most recently from bfdbed2 to edfe3d4 on November 12, 2024 at 22:18
@lialan lialan (Contributor, Author) commented Nov 12, 2024

> What would happen if the start of the store is aligned? Would it still generate atomic stores?

It will operate as usual; the atomics and the extra handling only happen when needed.

> I haven't looked into the details yet, but we only need an atomic store if we cannot guarantee that there are no competing stores. We cannot have this information at this level, but a caller might. I think it would be better to allow a caller to indicate that atomic stores are not needed, with atomic stores as the default.

It feels like, in such a non-competing case, we could do strength reduction and use a masked store instead of an atomic one for the unaligned parts (the first and the last byte, if unaligned).

@MaheshRavishankar MaheshRavishankar (Contributor)

> It feels like, in such a non-competing case, we could do strength reduction and use a masked store instead of an atomic one for the unaligned parts (the first and the last byte, if unaligned).

You can't always tell from program analysis alone that the atomic isn't needed; it depends on the caller's context. For example, the caller may know that, even dynamically, the stores never overlap, and it is hard to recover that information at this late stage. So I think it would be good to have an option that avoids the atomics (since they are expensive) and lets the caller opt in or out.

@lialan lialan marked this pull request as ready for review November 14, 2024 20:30