Skip to content

[mlir][vector] Add more tests for ConvertVectorToLLVM (1/n) #101936

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Aug 5, 2024

Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:

  • vector.bitcast
  • vector.broadcast

Note, this has uncovered some missing logic in BroadcastOpLowering.
This PR fixes the most basic cases where the scalable flags were dropped
and the generated code was incorrect.

The BroadcastOpLowering pattern is effectively disabled for scalable
vectors in more complex cases where an SCF loop would be required to
loop over the scalable dims, e.g.:

 %0 = vector.broadcast %arg0 : vector<[4]x1x2xf32> to vector<[4]x3x2xf32>

These cases are marked as "Stetch not at start" in the code. In those
case, support for scalable vectors is left as a TODO.

Depends on #101928 - only review the top commit

@llvmbot
Copy link
Member

llvmbot commented Aug 5, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes
  • [mlir][vector] Clarify the semantics of BroadcastOp
  • [mlir][vector] Add more tests for ConvertVectorToLLVM (1/n)

Patch is 27.91 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/101936.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.h (+5-1)
  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+2)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+35-10)
  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp (+6-1)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+236)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+14)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index ac55433fadb2f..9f61f7c866d3d 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -68,9 +68,13 @@ enum class BroadcastableToResult {
   DimensionMismatch = 2,
   SourceTypeNotAVector = 3
 };
+struct VectorDim {
+  int64_t dim;
+  bool scalableFlag;
+};
 BroadcastableToResult
 isBroadcastableTo(Type srcType, VectorType dstVectorType,
-                  std::pair<int, int> *mismatchingDims = nullptr);
+                  std::pair<VectorDim, VectorDim> *mismatchingDims = nullptr);
 
 /// Collect a set of vector-to-vector canonicalization patterns.
 void populateVectorToVectorCanonicalizationPatterns(RewritePatternSet &patterns,
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 434ff3956c250..08bff3d5e1382 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -367,6 +367,8 @@ def Vector_BroadcastOp :
                                s_1     x .. x s_j x .. x s_k
                <duplication>         <potential stretch>
        ```
+    * a scalable unit dimeension, `[1]`, must match exactly.
+
     The source operand is duplicated over all the missing leading dimensions
     and stretched over the trailing dimensions where the source has a non-equal
     dimension of 1. These rules imply that any scalar broadcast (k=0) to any
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5047bd925d4c5..673c128932893 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2371,9 +2371,9 @@ Value BroadcastOp::createOrFoldBroadcastOp(
   return res;
 }
 
-BroadcastableToResult
-mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
-                                std::pair<int, int> *mismatchingDims) {
+BroadcastableToResult mlir::vector::isBroadcastableTo(
+    Type srcType, VectorType dstVectorType,
+    std::pair<VectorDim, VectorDim> *mismatchingDims) {
   // Broadcast scalar to vector of the same element type.
   if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
       getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
@@ -2391,12 +2391,28 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
   // (all leading dimensions are simply duplicated).
   int64_t lead = dstRank - srcRank;
   for (int64_t r = 0; r < srcRank; ++r) {
+    bool mismatch = false;
+
+    // Check fixed-width dims
     int64_t srcDim = srcVectorType.getDimSize(r);
     int64_t dstDim = dstVectorType.getDimSize(lead + r);
-    if (srcDim != 1 && srcDim != dstDim) {
+    if ((srcDim != 1 && srcDim != dstDim))
+      mismatch = true;
+
+    // Check scalable flags
+    bool srcDimScalableFlag = srcVectorType.getScalableDims()[r];
+    bool dstDimScalableFlag = dstVectorType.getScalableDims()[lead + r];
+    if ((srcDim == 1 && srcDimScalableFlag && dstDim != 1) ||
+        (srcDimScalableFlag && !dstDimScalableFlag))
+      mismatch = true;
+
+    if (mismatch) {
       if (mismatchingDims) {
-        mismatchingDims->first = srcDim;
-        mismatchingDims->second = dstDim;
+        mismatchingDims->first.dim = srcDim;
+        mismatchingDims->first.scalableFlag = srcDimScalableFlag;
+
+        mismatchingDims->second.dim = dstDim;
+        mismatchingDims->second.scalableFlag = dstDimScalableFlag;
       }
       return BroadcastableToResult::DimensionMismatch;
     }
@@ -2406,16 +2422,25 @@ mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
 }
 
 LogicalResult BroadcastOp::verify() {
-  std::pair<int, int> mismatchingDims;
+  std::pair<VectorDim, VectorDim> mismatchingDims;
   BroadcastableToResult res = isBroadcastableTo(
       getSourceType(), getResultVectorType(), &mismatchingDims);
   if (res == BroadcastableToResult::Success)
     return success();
   if (res == BroadcastableToResult::SourceRankHigher)
     return emitOpError("source rank higher than destination rank");
-  if (res == BroadcastableToResult::DimensionMismatch)
-    return emitOpError("dimension mismatch (")
-           << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
+  if (res == BroadcastableToResult::DimensionMismatch) {
+    std::string msg =
+        (Twine("dimension mismatch (") +
+         (mismatchingDims.first.scalableFlag ? "[" : "") +
+         std::to_string(mismatchingDims.first.dim) +
+         (mismatchingDims.first.scalableFlag ? "]" : "") + " vs. " +
+         (mismatchingDims.second.scalableFlag ? "[" : "") +
+         std::to_string(mismatchingDims.second.dim) +
+         (mismatchingDims.second.scalableFlag ? "]" : "") + ")")
+            .str();
+    return emitOpError(msg);
+  }
   if (res == BroadcastableToResult::SourceTypeNotAVector)
     return emitOpError("source type is not a vector");
   llvm_unreachable("unexpected vector.broadcast op error");
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index 32e7eb27f5e29..6c36bbaee8523 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -125,7 +125,8 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
     //   ..
     //   %x = [%a,%b,%c,%d]
     VectorType resType =
-        VectorType::get(dstType.getShape().drop_front(), eltType);
+        VectorType::get(dstType.getShape().drop_front(), eltType,
+                        dstType.getScalableDims().drop_front());
     Value result = rewriter.create<arith::ConstantOp>(
         loc, dstType, rewriter.getZeroAttr(dstType));
     if (m == 0) {
@@ -136,6 +137,10 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
         result = rewriter.create<vector::InsertOp>(loc, bcst, result, d);
     } else {
       // Stetch not at start.
+      if (dstType.getScalableDims()[0]) {
+        // TODO: For scalable vectors we should emit an scf.for loop.
+        return failure();
+      }
       for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) {
         Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), d);
         Value bcst = rewriter.create<vector::BroadcastOp>(loc, resType, ext);
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index c310954b906e4..40f5fd9baadab 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -23,6 +23,15 @@ func.func @bitcast_f32_to_i32_vector(%input: vector<16xf32>) -> vector<16xi32> {
 // CHECK-SAME:  %[[input:.*]]: vector<16xf32>
 // CHECK:       llvm.bitcast %[[input]] : vector<16xf32> to vector<16xi32>
 
+func.func @bitcast_f32_to_i32_vector_scalable(%input: vector<[16]xf32>) -> vector<[16]xi32> {
+  %0 = vector.bitcast %input : vector<[16]xf32> to vector<[16]xi32>
+  return %0 : vector<[16]xi32>
+}
+
+// CHECK-LABEL: @bitcast_f32_to_i32_scalable_vector
+// CHECK-SAME:  %[[input:.*]]: vector<[16]xf32>
+// CHECK:       llvm.bitcast %[[input]] : vector<[16]xf32> to vector<[16]xi32>
+
 // -----
 
 func.func @bitcast_i8_to_f32_vector(%input: vector<64xi8>) -> vector<16xf32> {
@@ -34,6 +43,15 @@ func.func @bitcast_i8_to_f32_vector(%input: vector<64xi8>) -> vector<16xf32> {
 // CHECK-SAME:  %[[input:.*]]: vector<64xi8>
 // CHECK:       llvm.bitcast %[[input]] : vector<64xi8> to vector<16xf32>
 
+func.func @bitcast_i8_to_f32_vector_scalable(%input: vector<[64]xi8>) -> vector<[16]xf32> {
+  %0 = vector.bitcast %input : vector<[64]xi8> to vector<[16]xf32>
+  return %0 : vector<[16]xf32>
+}
+
+// CHECK-LABEL: @bitcast_i8_to_f32_scalable_vector
+// CHECK-SAME:  %[[input:.*]]: vector<[64]xi8>
+// CHECK:       llvm.bitcast %[[input]] : vector<[64]xi8> to vector<[16]xf32>
+
 // -----
 
 func.func @bitcast_index_to_i8_vector(%input: vector<16xindex>) -> vector<128xi8> {
@@ -46,6 +64,16 @@ func.func @bitcast_index_to_i8_vector(%input: vector<16xindex>) -> vector<128xi8
 // CHECK:       %[[T0:.*]] = builtin.unrealized_conversion_cast %[[input]] : vector<16xindex> to vector<16xi64>
 // CHECK:       llvm.bitcast %[[T0]] : vector<16xi64> to vector<128xi8>
 
+func.func @bitcast_index_to_i8_vector_scalable(%input: vector<[16]xindex>) -> vector<[128]xi8> {
+  %0 = vector.bitcast %input : vector<[16]xindex> to vector<[128]xi8>
+  return %0 : vector<[128]xi8>
+}
+
+// CHECK-LABEL: @bitcast_index_to_i8_scalable_vector
+// CHECK-SAME:  %[[input:.*]]: vector<[16]xindex>
+// CHECK:       %[[T0:.*]] = builtin.unrealized_conversion_cast %[[input]] : vector<[16]xindex> to vector<[16]xi64>
+// CHECK:       llvm.bitcast %[[T0]] : vector<[16]xi64> to vector<[128]xi8>
+
 // -----
 
 func.func @broadcast_vec0d_from_f32(%arg0: f32) -> vector<f32> {
@@ -80,6 +108,17 @@ func.func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> {
 // CHECK:       %[[T1:.*]] = llvm.shufflevector %[[T0]]
 // CHECK:       return %[[T1]] : vector<2xf32>
 
+
+func.func @broadcast_vec1d_from_f32_scalable(%arg0: f32) -> vector<[2]xf32> {
+  %0 = vector.broadcast %arg0 : f32 to vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}
+// CHECK-LABEL: @broadcast_scalable_vec1d_from_f32
+// CHECK-SAME:  %[[A:.*]]: f32)
+// CHECK:       %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK:       %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK:       return %[[T1]] : vector<[2]xf32>
+
 // -----
 
 func.func @broadcast_vec1d_from_index(%arg0: index) -> vector<2xindex> {
@@ -94,6 +133,18 @@ func.func @broadcast_vec1d_from_index(%arg0: index) -> vector<2xindex> {
 // CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<2xi64> to vector<2xindex>
 // CHECK:       return %[[T2]] : vector<2xindex>
 
+func.func @broadcast_vec1d_from_index_scalable(%arg0: index) -> vector<[2]xindex> {
+  %0 = vector.broadcast %arg0 : index to vector<[2]xindex>
+  return %0 : vector<[2]xindex>
+}
+// CHECK-LABEL: @broadcast_scalable_vec1d_from_index
+// CHECK-SAME:  %[[A:.*]]: index)
+// CHECK:       %[[A1:.*]] = builtin.unrealized_conversion_cast %[[A]] : index to i64
+// CHECK:       %[[T0:.*]] = llvm.insertelement %[[A1]]
+// CHECK:       %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<[2]xi64> to vector<[2]xindex>
+// CHECK:       return %[[T2]] : vector<[2]xindex>
+
 // -----
 
 func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
@@ -109,6 +160,19 @@ func.func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> {
 // CHECK:       %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
 // CHECK:       return %[[T4]] : vector<2x3xf32>
 
+func.func @broadcast_vec2d_from_scalar_scalable(%arg0: f32) -> vector<2x[3]xf32> {
+  %0 = vector.broadcast %arg0 : f32 to vector<2x[3]xf32>
+  return %0 : vector<2x[3]xf32>
+}
+// CHECK-LABEL: @broadcast_scalable_vec2d_from_scalar(
+// CHECK-SAME:  %[[A:.*]]: f32)
+// CHECK:       %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK:       %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK:       %[[T2:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[0] : !llvm.array<2 x vector<[3]xf32>>
+// CHECK:       %[[T3:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[1] : !llvm.array<2 x vector<[3]xf32>>
+// CHECK:       %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : !llvm.array<2 x vector<[3]xf32>> to vector<2x[3]xf32>
+// CHECK:       return %[[T4]] : vector<2x[3]xf32>
+
 // -----
 
 func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
@@ -125,6 +189,21 @@ func.func @broadcast_vec3d_from_scalar(%arg0: f32) -> vector<2x3x4xf32> {
 // CHECK:       %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : !llvm.array<2 x array<3 x vector<4xf32>>> to vector<2x3x4xf32>
 // CHECK:       return %[[T4]] : vector<2x3x4xf32>
 
+
+func.func @broadcast_vec3d_from_scalar_scalable(%arg0: f32) -> vector<2x3x[4]xf32> {
+  %0 = vector.broadcast %arg0 : f32 to vector<2x3x[4]xf32>
+  return %0 : vector<2x3x[4]xf32>
+}
+// CHECK-LABEL: @broadcast_scalable_vec3d_from_scalar(
+// CHECK-SAME:  %[[A:.*]]: f32)
+// CHECK:       %[[T0:.*]] = llvm.insertelement %[[A]]
+// CHECK:       %[[T1:.*]] = llvm.shufflevector %[[T0]]
+// CHECK:       %[[T2:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[0, 0] : !llvm.array<2 x array<3 x vector<[4]xf32>>>
+// ...
+// CHECK:       %[[T3:.*]] = llvm.insertvalue %[[T1]], %{{.*}}[1, 2] : !llvm.array<2 x array<3 x vector<[4]xf32>>>
+// CHECK:       %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : !llvm.array<2 x array<3 x vector<[4]xf32>>> to vector<2x3x[4]xf32>
+// CHECK:       return %[[T4]] : vector<2x3x[4]xf32>
+
 // -----
 
 func.func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> {
@@ -135,6 +214,14 @@ func.func @broadcast_vec1d_from_vec1d(%arg0: vector<2xf32>) -> vector<2xf32> {
 // CHECK-SAME:  %[[A:.*]]: vector<2xf32>)
 // CHECK:       return %[[A]] : vector<2xf32>
 
+func.func @broadcast_vec1d_from_vec1d_scalable(%arg0: vector<[2]xf32>) -> vector<[2]xf32> {
+  %0 = vector.broadcast %arg0 : vector<[2]xf32> to vector<[2]xf32>
+  return %0 : vector<[2]xf32>
+}
+// CHECK-LABEL: @broadcast_scalable_vec1d_from_vec1d(
+// CHECK-SAME:  %[[A:.*]]: vector<[2]xf32>)
+// CHECK:       return %[[A]] : vector<[2]xf32>
+
 // -----
 
 func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
@@ -172,6 +259,20 @@ func.func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> {
 // CHECK:       %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<3 x vector<2xf32>> to vector<3x2xf32>
 // CHECK:       return %[[T5]] : vector<3x2xf32>
 
+func.func @broadcast_vec2d_from_vec1d_scalable(%arg0: vector<[2]xf32>) -> vector<3x[2]xf32> {
+  %0 = vector.broadcast %arg0 : vector<[2]xf32> to vector<3x[2]xf32>
+  return %0 : vector<3x[2]xf32>
+}
+// CHECK-LABEL: @broadcast_scalable_vec2d_from_vec1d(
+// CHECK-SAME:  %[[A:.*]]: vector<[2]xf32>)
+// CHECK:       %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x[2]xf32>
+// CHECK:       %[[T1:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
+// CHECK:       %[[T2:.*]] = llvm.insertvalue %[[A]], %[[T1]][0] : !llvm.array<3 x vector<[2]xf32>>
+// CHECK:       %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][1] : !llvm.array<3 x vector<[2]xf32>>
+// CHECK:       %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][2] : !llvm.array<3 x vector<[2]xf32>>
+// CHECK:       %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : !llvm.array<3 x vector<[2]xf32>> to vector<3x[2]xf32>
+// CHECK:       return %[[T5]] : vector<3x[2]xf32>
+
 // -----
 
 func.func @broadcast_vec2d_from_index_vec1d(%arg0: vector<2xindex>) -> vector<3x2xindex> {
@@ -188,6 +289,20 @@ func.func @broadcast_vec2d_from_index_vec1d(%arg0: vector<2xindex>) -> vector<3x
 // CHECK:       %[[T4:.*]] = builtin.unrealized_conversion_cast %{{.*}} : !llvm.array<3 x vector<2xi64>> to vector<3x2xindex>
 // CHECK:       return %[[T4]] : vector<3x2xindex>
 
+func.func @broadcast_vec2d_from_index_vec1d_scalable(%arg0: vector<[2]xindex>) -> vector<3x[2]xindex> {
+  %0 = vector.broadcast %arg0 : vector<[2]xindex> to vector<3x[2]xindex>
+  return %0 : vector<3x[2]xindex>
+}
+// CHECK-LABEL: @broadcast_scalable_vec2d_from_index_vec1d(
+// CHECK-SAME:  %[[A:.*]]: vector<[2]xindex>)
+// CHECK:       %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<[2]xindex> to vector<[2]xi64>
+// CHECK:       %[[T0:.*]] = arith.constant dense<0> : vector<3x[2]xindex>
+// CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xindex> to !llvm.array<3 x vector<[2]xi64>>
+// CHECK:       %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<3 x vector<[2]xi64>>
+
+// CHECK:       %[[T4:.*]] = builtin.unrealized_conversion_cast %{{.*}} : !llvm.array<3 x vector<[2]xi64>> to vector<3x[2]xindex>
+// CHECK:       return %[[T4]] : vector<3x[2]xindex>
+
 // -----
 
 func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> {
@@ -213,6 +328,29 @@ func.func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32>
 // CHECK:       %[[T11:.*]] = builtin.unrealized_conversion_cast %[[T10]] : !llvm.array<4 x array<3 x vector<2xf32>>> to vector<4x3x2xf32>
 // CHECK:       return %[[T11]] : vector<4x3x2xf32>
 
+func.func @broadcast_vec3d_from_vec1d_scalable(%arg0: vector<[2]xf32>) -> vector<4x3x[2]xf32> {
+  %0 = vector.broadcast %arg0 : vector<[2]xf32> to vector<4x3x[2]xf32>
+  return %0 : vector<4x3x[2]xf32>
+}
+// CHECK-LABEL: @broadcast_scalable_vec3d_from_vec1d(
+// CHECK-SAME:  %[[A:.*]]: vector<[2]xf32>)
+// CHECK-DAG:   %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x[2]xf32>
+// CHECK-DAG:   %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
+// CHECK-DAG:   %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x[2]xf32>
+// CHECK-DAG:   %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x[2]xf32> to !llvm.array<4 x array<3 x vector<[2]xf32>>>
+
+// CHECK:       %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<[2]xf32>>
+// CHECK:       %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][1] : !llvm.array<3 x vector<[2]xf32>>
+// CHECK:       %[[T5:.*]] = llvm.insertvalue %[[A]], %[[T4]][2] : !llvm.array<3 x vector<[2]xf32>>
+
+// CHECK:       %[[T7:.*]] = llvm.insertvalue %[[T5]], %[[T6]][0] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK:       %[[T8:.*]] = llvm.insertvalue %[[T5]], %[[T7]][1] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK:       %[[T9:.*]] = llvm.insertvalue %[[T5]], %[[T8]][2] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK:       %[[T10:.*]] = llvm.insertvalue %[[T5]], %[[T9]][3] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+
+// CHECK:       %[[T11:.*]] = builtin.unrealized_conversion_cast %[[T10]] : !llvm.array<4 x array<3 x vector<[2]xf32>>> to vector<4x3x[2]xf32>
+// CHECK:       return %[[T11]] : vector<4x3x[2]xf32>
+
 // -----
 
 func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf32> {
@@ -231,6 +369,22 @@ func.func @broadcast_vec3d_from_vec2d(%arg0: vector<3x2xf32>) -> vector<4x3x2xf3
 // CHECK:       %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T9]] : !llvm.array<4 x array<3 x vector<2xf32>>> to vector<4x3x2xf32>
 // CHECK:       return %[[T10]] : vector<4x3x2xf32>
 
+func.func @broadcast_vec3d_from_vec2d_scalable(%arg0: vector<3x[2]xf32>) -> vector<4x3x[2]xf32> {
+  %0 = vector.broadcast %arg0 : vector<3x[2]xf32> to vector<4x3x[2]xf32>
+  return %0 : vector<4x3x[2]xf32>
+}
+// CHECK-LABEL: @broadcast_scalable_vec3d_from_vec2d(
+// CHECK-SAME:  %[[A:.*]]: vector<3x[2]xf32>)
+// CHECK:       %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x[2]xf32> to !llvm.array<3 x vector<[2]xf32>>
+// CHECK:       %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x[2]xf32>
+// CHECK:       %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<4x3x[2]xf32> to !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK:       %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK:       %[[T5:.*]] = llvm.insertvalue %[[T1]], %[[T3]][1] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK:       %[[T7:.*]] = llvm.insertvalue %[[T1]], %[[T5]][2] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK:       %[[T9:.*]] = llvm.insertvalue %[[T1]], %[[T7]][3] : !llvm.array<4 x array<3 x vector<[2]xf32>>>
+// CHECK:       %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T9]] : !llvm.array<4 x array<3 x vector<[2]xf32>>> to vector<4x3x[2]xf32>
+// CHECK:       return %[[T10]] : vector<4x3x[2]xf32>
+
 
 // -----
 
@@ -246,6 +400,18 @@ func.func @broadcast_stretch(%arg0: vector<1xf32>) -> vector<4xf32> {
 // CHECK:       %[[T4:.*]] = llvm.shufflevector %[[T3]]
 // CHECK:       return %[[T4]] : vector<4xf32>
 
+func.func @broadcast_stretch_scalable(%arg0: vector<1xf32>) -> vector<[4]xf32> {
+  %0 = vector.broadcast %arg0 : vector<1xf32> to vector<[4]...
[truncated]

@banach-space banach-space changed the title andrzej/extend vector to llvm test [mlir][vector] Add more tests for ConvertVectorToLLVM (1/n) Aug 5, 2024
@banach-space banach-space force-pushed the andrzej/extend_vector_to_llvm_test branch from 1fc8706 to f864c7a Compare August 5, 2024 09:15
Copy link
Contributor

@nujaa nujaa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing to complain about.

Adds tests with scalable vectors for the Vector-To-LLVM conversion pass.
Covers the following Ops:
  * vector.bitcast
  * vector.broadcast

Note, this has uncovered some missing logic in `BroadcastOpLowering`.
This PR fixes the most basic cases where the scalable flags were dropped
and the generated code was incorrect.

The `BroadcastOpLowering` pattern is effectively disabled for scalable
vectors in more complex cases where an SCF loop would be required to
loop over the scalable dims, e.g.:
```mlir
 %0 = vector.broadcast %arg0 : vector<[4]x1x2xf32> to vector<[4]x3x2xf32>
```

These cases are marked as "Stetch not at start" in the code. In those
case, support for scalable vectors is left as a TODO.
@banach-space banach-space force-pushed the andrzej/extend_vector_to_llvm_test branch from f864c7a to 02faa5d Compare August 8, 2024 09:42
Refine BroadcastOp::verify. Specifically, relax it to allow:
```mlir
%0 = vector.broadcast %arg0 : vector<1xf32> to vector<[4]xf32>
```
@banach-space banach-space force-pushed the andrzej/extend_vector_to_llvm_test branch from 02faa5d to f2e1ec8 Compare August 8, 2024 10:03
@banach-space banach-space merged commit 22a1302 into llvm:main Aug 8, 2024
7 checks passed
@banach-space banach-space deleted the andrzej/extend_vector_to_llvm_test branch August 8, 2024 14:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants