Skip to content

[mlir][spirv] Add support for VectorAnyINTEL capability #68034

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

mshahneo
Copy link
Contributor

@mshahneo mshahneo commented Oct 2, 2023

Allow vector of any lengths between [2-2^32-1].
VectorAnyINTEL capability (part of "SPV_INTEL_vector_compute" extension) relaxes the length constraint on SPIR-V vector sizes from 2,3, and 4.

@llvmbot
Copy link
Member

llvmbot commented Oct 2, 2023

@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Changes

Allow vector of any lengths between [2-2^32-1].
VectorAnyINTEL capability (part of "SPV_INTEL_vector_compute" extension) relaxes the length constraint on SPIR-V vector sizes from 2,3, and 4.


Patch is 44.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68034.diff

16 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td (+8-3)
  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+70)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp (+5-2)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (+16-6)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+102-24)
  • (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir (+2-2)
  • (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir (+40)
  • (modified) mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir (+15-2)
  • (modified) mlir/test/Conversion/GPUToSPIRV/reductions.mlir (+32-32)
  • (modified) mlir/test/Dialect/SPIRV/IR/bit-ops.mlir (+3-3)
  • (modified) mlir/test/Dialect/SPIRV/IR/gl-ops.mlir (+1-1)
  • (modified) mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir (+2-2)
  • (modified) mlir/test/Dialect/SPIRV/IR/logical-ops.mlir (+1-1)
  • (modified) mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir (+22-20)
  • (modified) mlir/test/Target/SPIRV/arithmetic-ops.mlir (+3-3)
  • (modified) mlir/test/Target/SPIRV/ocl-ops.mlir (+6)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 1013cbc8ca562b7..c458a500eb367f9 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4146,7 +4146,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
 def SPIRV_Float32 : TypeAlias<F32, "Float32">;
 def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
 def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
-def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
+// Remove the vector size restriction.
+// Although the vector size can be upto (2^64-1), uint64,
+// 2^32-1 (UNINT32_MAX>) is a more realistic number, it should serve the purpose
+// for all practical cases.
+// Also unsigned is used for the number elements for composite tyeps.
+def SPIRV_Vector : VectorOfLengthRangeAndType<[2, 0xFFFFFFFF],
                                        [SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
 // Component type check is done in the type parser for the following SPIR-V
 // dialect-specific types so we use "Any" here.
@@ -4206,10 +4211,10 @@ class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
     "Joint Matrix">;
 
 class SPIRV_ScalarOrVectorOf<Type type> :
-    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;
+    AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>]>;
 
 class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
-    AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
+    AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>,
                SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;
 
 class SPIRV_MatrixOrCoopMatrixOf<Type type> :
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 4fc14e30b8a10d0..703122547df7493 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -546,6 +546,76 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
   ScalableVectorOfLength<allowedLengths>.summary,
   "::mlir::VectorType">;
 
+// Whether the number of elements of a vector is from the given
+// `allowedRanges` list, the list has two values, start and end
+// of the range (inclusive).
+class IsVectorOfLengthRangePred<list<int> allowedRanges>
+  : And<[IsVectorTypePred,
+        And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements()>= }] # allowedRanges[0]>,
+            CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;
+
+// Whether the number of elements of a fixed-length vector is from the given
+// `allowedRanges` list, the list has two values, start and end of the range (inclusive).
+class IsFixedVectorOfLengthRangePred<list<int> allowedRanges>
+  : And<[IsFixedVectorTypePred,
+        And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>,
+            CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;
+
+// Whether the number of elements of a scalable vector is from the given
+// `allowedRanges` list, the list has two values, start and end of the range (inclusive).
+class IsScalableVectorOfLengthRangePred<list<int> allowedRanges>
+  : And<[IsScalableVectorTypePred,
+        And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>,
+            CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;
+
+// Any vector where the number of elements is from the given
+// `allowedRanges` list.
+class VectorOfLengthRange<list<int> allowedRanges>
+  : Type<IsVectorOfLengthRangePred<allowedRanges>,
+    " of length " # !interleave(allowedRanges, "-"),
+    "::mlir::VectorType">;
+
+// Any fixed-length vector where the number of elements is from the given
+// `allowedRanges` list.
+class FixedVectorOfLengthRange<list<int> allowedRanges>
+  : Type<IsFixedVectorOfLengthRangePred<allowedRanges>,
+    " of length " # !interleave(allowedRanges, "-"),
+    "::mlir::VectorType">;
+
+// Any scalable vector where the number of elements is from the given
+// `allowedRanges` list.
+class ScalableVectorOfLengthRange<list<int> allowedRanges>
+  : Type<IsScalableVectorOfLengthRangePred<allowedRanges>,
+    " of length " # !interleave(allowedRanges, "-"),
+    "::mlir::VectorType">;
+
+// Any vector where the number of elements is from the given
+// `allowedRanges` list and the type is from the given `allowedTypes`
+// list.
+class VectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
+  : Type<And<[VectorOf<allowedTypes>.predicate, VectorOfLengthRange<allowedRanges>.predicate]>,
+        VectorOf<allowedTypes>.summary # VectorOfLengthRange<allowedRanges>.summary,
+        "::mlir::VectorType">;
+
+// Any fixed-length vector where the number of elements is from the given
+// `allowedRanges` list and the type is from the given `allowedTypes`
+// list.
+class FixedVectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
+  : Type<
+      And<[FixedVectorOf<allowedTypes>.predicate, FixedVectorOfLengthRange<allowedRanges>.predicate]>,
+      FixedVectorOf<allowedTypes>.summary # FixedVectorOfLengthRange<allowedRanges>.summary,
+      "::mlir::VectorType">;
+
+// Any scalable vector where the number of elements is from the given
+// `allowedRanges` list and the type is from the given `allowedTypes`
+// list.
+class ScalableVectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
+  : Type<
+      And<[ScalableVectorOf<allowedTypes>.predicate, ScalableVectorOfLengthRange<allowedRanges>.predicate]>,
+      ScalableVectorOf<allowedTypes>.summary # ScalableVectorOfLengthRange<allowedRanges>.summary,
+      "::mlir::VectorType">;
+
+
 def AnyVector : VectorOf<[AnyType]>;
 // Temporary vector type clone that allows gradual transition to 0-D vectors.
 def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a51d77dda78bf2f..be85d3c330a887a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -184,9 +184,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
       parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
       return Type();
     }
-    if (t.getNumElements() > 4) {
+    // Number of elements should be between [2 - 2^32 -1],
+    // since getNumElements() returns an unsigned, the upper limit check is
+    // unnecessary.
+    if (t.getNumElements() < 2) {
       parser.emitError(
-          typeLoc, "vector length has to be less than or equal to 4 but found ")
+          typeLoc, "vector length has to be between [2 - 2^32 -1] but found ")
           << t.getNumElements();
       return Type();
     }
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 39d6603a46f965d..9d39d99b4148253 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -101,9 +101,11 @@ bool CompositeType::classof(Type type) {
 }
 
 bool CompositeType::isValid(VectorType type) {
-  return type.getRank() == 1 &&
-         llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
-         llvm::isa<ScalarType>(type.getElementType());
+  // Number of elements should be between [2 - 2^32 -1],
+  // since getNumElements() returns an unsigned, the upper limit check is
+  // unnecessary.
+  return type.getRank() == 1 && llvm::isa<ScalarType>(type.getElementType()) &&
+         type.getNumElements() >= 2;
 }
 
 Type CompositeType::getElementType(unsigned index) const {
@@ -171,9 +173,17 @@ void CompositeType::getCapabilities(
       .Case<VectorType>([&](VectorType type) {
         auto vecSize = getNumElements();
         if (vecSize == 8 || vecSize == 16) {
-          static const Capability caps[] = {Capability::Vector16};
-          ArrayRef<Capability> ref(caps, std::size(caps));
-          capabilities.push_back(ref);
+          static constexpr Capability caps[] = {Capability::Vector16,
+                                                Capability::VectorAnyINTEL};
+          capabilities.push_back(caps);
+        }
+        // VectorAnyINTEL capability removes the vector size restriction and
+        // allows the vector size to be up to (2^32-1).
+        // Vector16 capability allows the vector size to be 8 and 16
+        SmallVector<unsigned, 5> allowedVecRange = {2, 3, 4, 8, 16};
+        if (vecSize >= 2 && !llvm::is_contained(allowedVecRange, vecSize)) {
+          static constexpr Capability caps[] = {Capability::VectorAnyINTEL};
+          capabilities.push_back(caps);
         }
         return llvm::cast<ScalarType>(type.getElementType())
             .getCapabilities(capabilities, storage);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c75d217663a9e09..25e6a080642e681 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -43,9 +43,13 @@ using namespace mlir;
 template <typename LabelT>
 static LogicalResult checkExtensionRequirements(
     LabelT label, const spirv::TargetEnv &targetEnv,
-    const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
+    const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
+    const ArrayRef<spirv::Extension> elidedCandidates = {}) {
   for (const auto &ors : candidates) {
-    if (targetEnv.allows(ors))
+    if (targetEnv.allows(ors) ||
+        llvm::any_of(elidedCandidates, [&ors](spirv::Extension elidedExt) {
+          return llvm::is_contained(ors, elidedExt);
+        }))
       continue;
 
     LLVM_DEBUG({
@@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements(
 template <typename LabelT>
 static LogicalResult checkCapabilityRequirements(
     LabelT label, const spirv::TargetEnv &targetEnv,
-    const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
+    const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
+    const ArrayRef<spirv::Capability> elidedCandidates = {}) {
   for (const auto &ors : candidates) {
-    if (targetEnv.allows(ors))
+    if (targetEnv.allows(ors) ||
+        llvm::any_of(elidedCandidates, [&ors](spirv::Capability elidedCap) {
+          return llvm::is_contained(ors, elidedCap);
+        }))
       continue;
 
     LLVM_DEBUG({
@@ -90,8 +98,55 @@ static LogicalResult checkCapabilityRequirements(
   return success();
 }
 
-/// Returns true if the given `storageClass` needs explicit layout when used in
-/// Shader environments.
+/// Check capabilities and extensions requirements
+/// Checks that `capCandidates`, `extCandidates`, and capability
+/// (`capCandidates`) infered extension requirements are possible to be
+/// satisfied with the given `targetEnv`.
+/// It also provides a way to relax requirements for certain capabilities and
+/// extensions (e.g., `elidedCapCandidates`, `elidedExtCandidates`), this is to
+/// allow passes to relax certain requirements based on an option (e.g.,
+/// relaxing bitwidth requirement, see `convertScalarType()`,
+/// `ConvertVectorType()`).
+template <typename LabelT>
+static LogicalResult checkCapabilityAndExtensionRequirements(
+    LabelT label, const spirv::TargetEnv &targetEnv,
+    const spirv::SPIRVType::CapabilityArrayRefVector &capCandidates,
+    const spirv::SPIRVType::ExtensionArrayRefVector &extCandidates,
+    const ArrayRef<spirv::Capability> elidedCapCandidates = {},
+    const ArrayRef<spirv::Extension> elidedExtCandidates = {}) {
+  SmallVector<ArrayRef<spirv::Extension>, 8> updatedExtCandidates;
+  llvm::append_range(updatedExtCandidates, extCandidates);
+
+  if (failed(checkCapabilityRequirements(label, targetEnv, capCandidates,
+                                         elidedCapCandidates)))
+    return failure();
+  // Add capablity infered extensions to the list of extension requirement list,
+  // only considers the capabilities that already available in the `targetEnv`.
+
+  // WARNING: Some capabilities are part of both the core SPIR-V
+  // specification and an extension (e.g., 'Groups' capability is part of both
+  // core specification and SPV_AMD_shader_ballot extension, hence we should
+  // relax the capability inferred extension for these cases).
+  static const spirv::Capability multiModalCaps[] = {spirv::Capability::Groups};
+  ArrayRef<spirv::Capability> multiModalCapsArrayRef(multiModalCaps,
+                                                     std::size(multiModalCaps));
+
+  for (auto cap : targetEnv.getAttr().getCapabilities()) {
+    if (llvm::any_of(multiModalCapsArrayRef,
+                     [&cap](spirv::Capability mMCap) { return cap == mMCap; }))
+      continue;
+    std::optional<ArrayRef<spirv::Extension>> ext = getExtensions(cap);
+    if (ext)
+      updatedExtCandidates.push_back(*ext);
+  }
+  if (failed(checkExtensionRequirements(label, targetEnv, updatedExtCandidates,
+                                        elidedExtCandidates)))
+    return failure();
+  return success();
+}
+
+/// Returns true if the given `storageClass` needs explicit layout when used
+/// in Shader environments.
 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
   switch (storageClass) {
   case spirv::StorageClass::PhysicalStorageBuffer:
@@ -230,8 +285,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
   type.getCapabilities(capabilities, storageClass);
 
   // If all requirements are met, then we can accept this type as-is.
-  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
-      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
+  if (succeeded(checkCapabilityAndExtensionRequirements(
+          type, targetEnv, capabilities, extensions)))
     return type;
 
   // Otherwise we need to adjust the type, which really means adjusting the
@@ -342,15 +397,35 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
   cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
   cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
 
+  // If the bit-width related capabilities and extensions are not met
+  // for lower bit-width (<32-bit), convert it to 32-bit
+  auto elementType =
+      convertScalarType(targetEnv, options, scalarType, storageClass);
+  if (!elementType)
+    return nullptr;
+  type = VectorType::get(type.getShape(), elementType);
+
+  SmallVector<spirv::Capability, 4> elidedCaps;
+  SmallVector<spirv::Extension, 4> elidedExts;
+
+  // Relax the bitwidth requirements for capabilities and extensions
+  if (options.emulateLT32BitScalarTypes) {
+    elidedCaps.push_back(spirv::Capability::Int8);
+    elidedCaps.push_back(spirv::Capability::Int16);
+    elidedCaps.push_back(spirv::Capability::Float16);
+  }
+  // For capabilities whose requirements were relaxed, relax requirements for
+  // the extensions that were infered by those capabilities (e.g., elidedCaps)
+  for (spirv::Capability cap : elidedCaps) {
+    std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap);
+    if (ext)
+      llvm::append_range(elidedExts, *ext);
+  }
   // If all requirements are met, then we can accept this type as-is.
-  if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
-      succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
+  if (succeeded(checkCapabilityAndExtensionRequirements(
+          type, targetEnv, capabilities, extensions, elidedCaps, elidedExts)))
     return type;
 
-  auto elementType =
-      convertScalarType(targetEnv, options, scalarType, storageClass);
-  if (elementType)
-    return VectorType::get(type.getShape(), elementType);
   return nullptr;
 }
 
@@ -656,8 +731,9 @@ std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
   SmallVector<ArrayRef<spirv::Capability>, 2> caps;
   scalarType.getExtensions(exts);
   scalarType.getCapabilities(caps);
-  if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
-      failed(checkExtensionRequirements(type, targetEnv, exts))) {
+
+  if (failed(checkCapabilityAndExtensionRequirements(type, targetEnv, caps,
+                                                     exts))) {
     auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
     return castOp.getResult(0);
   }
@@ -1150,16 +1226,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
   SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
   SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
   for (Type valueType : valueTypes) {
-    typeExtensions.clear();
-    cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
-    if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
-                                          typeExtensions)))
-      return false;
-
     typeCapabilities.clear();
     cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
-    if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
-                                           typeCapabilities)))
+    typeExtensions.clear();
+    cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
+    // Checking for capability and extension requirements along with capability
+    // infered extensions.
+    // If a capability is present, the extension that
+    // supports it should also be present, this reduces the burden of adding
+    // extension requirement that may or maynot be added in
+    // CompositeType::getExtensions().
+    if (failed(checkCapabilityAndExtensionRequirements(
+            op->getName(), this->targetEnv, typeCapabilities, typeExtensions)))
       return false;
   }
 
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 0d92a8e676d8570..d61ace8d6876b87 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -11,9 +11,9 @@ module attributes {
     #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
 } {
 
-func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) {
+func.func @unsupported_5elem_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) {
   // expected-error@+1 {{failed to legalize operation 'arith.subi'}}
-  %1 = arith.subi %arg0, %arg0: vector<5xi32>
+  %1 = arith.subi %arg0, %arg1: vector<5xi32>
   return
 }
 
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 0221e4815a9397d..6ceeade486efd68 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1407,3 +1407,43 @@ func.func @float_scalar(%arg0: f16) {
 }
 
 } // end module
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// VectorAnyINTEL support
+//===----------------------------------------------------------------------===//
+
+// Check that with VectorAnyINTEL, VectorComputeINTEL capability,
+// and SPV_INTEL_vector_compute extension, any sized (2-2^32 -1) vector is allowed.
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Kernel, VectorAnyINTEL], [SPV_INTEL_vector_compute]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @any_vector
+func.func @any_vector(%arg0: vector<16xi32>, %arg1: vector<16xi32>) {
+  // CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<16xi32>
+  %0 = arith.subi %arg0, %arg1: vector<16xi32>
+  return
+}
+
+// CHECK-LABEL: @max_vector
+func.func @max_vector(%arg0: vector<4294967295xi32>, %arg1: vector<4294967295xi32>) {
+  // CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<4294967295xi32>
+  %0 = arith.subi %arg0, %arg1: vector<4294967295xi32>
+  return
+}
+
+
+// Check float vector types of any size.
+// CHECK-LABEL: @float_vector58
+func.func @float_vector58(%arg0: vector<5xf16>, %arg1: vector<8xf64>) {
+  // CHECK: spirv.FAdd %{{.*}}, %{{.*}}: vector<5xf16>
+  %0 = arith.addf %arg0, %arg0: vector<5xf16>
+  // CHECK: spirv.FMul %{{.*}}, %{{.*}}: vector<8xf64>
+  %1 = arith.mulf %a...
[truncated]

@mshahneo
Copy link
Contributor Author

mshahneo commented Oct 2, 2023

Please use the last commit for review. Added other commits as dependency. These commits are opened as separate PRs (#68030, #68031, #68033)

@mshahneo mshahneo removed mlir:core MLIR Core Infrastructure mlir:gpu mlir:ods labels Oct 2, 2023
@mshahneo mshahneo force-pushed the PR-vector-any-intel-all-commits branch from 47e14c0 to f1b971b Compare October 2, 2023 22:12
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:ods labels Oct 2, 2023
@mshahneo mshahneo force-pushed the PR-vector-any-intel-all-commits branch from f1b971b to 22cd31a Compare October 4, 2023 23:33
… a range of values

Add types and predicates for Vector, Fixed Vector, and Scalable Vector
whose number of elements is from a given `allowedRanges` list.
The list has two values, start and end of the range (inclusive).
Allow a way to relax requirements for certain capabilities and
extensions (e.g., `elidedCandidates`).

Also add a combined check for capabilities and extensions in
`checkCapabilityAndExtensionRequirements`.
This function checks capabilities, extensions, and
capability infered extension requirements.
…ility requirements

Replace the seperate extension and capability checking with combined check
`checkCapabilityAndExtensionRequirements()`. This makes the code flow simpler.
Also adds the extra check for capability inferred extension check.

Need for capability inferred extension check:
If a capability is a requirement, the respective extension that implements
it should also become an extension requirement, there were no support for
that check, as a result, the extension requirement had to be added separately.
This separate requirement addition causes problem when a feature is enabled by
multiple capability, and one of the capability is part of an extension. E.g.,
vector size of 16 can be enabled by both "Vector16" and "vectorAnyINTEL"
capability, however, only "vectorAnyINTEL" has an extension requirement
("SPV_INTEL_vector_compute"). Since the process of adding capability
and extension requirement are independent, there is no way, to handle
cases like this. Therefore, for cases like this, enable adding capability
requirement initially, then do the check for capability inferred extension.
Allow vector of any lengths between [2-2^32-1].
VectorAnyINTEL capability (part of "SPV_INTEL_vector_compute" extension)
relaxes the length constraint on SPIR-V vector sizes from 2,3, and 4.
@mshahneo mshahneo force-pushed the PR-vector-any-intel-all-commits branch from 22cd31a to 1041658 Compare October 12, 2023 17:04
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that relaxing the vector length requirements necessitates a new module-level check that the types are compatible with the declared capabilities. Do you have any opinions on this, @antiagainst?

// Any scalable vector where the minimum number of elements is from the given
// `allowedRanges` list and the type is from the given `allowedTypes`
// list.
class ScalableVectorOfMinLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we exclude any predicated not used by the SPIR-V dialect?

parser.emitError(
typeLoc, "vector length has to be less than or equal to 4 but found ")
typeLoc, "vector length has to be between [2 - 2^32 -1] but found ")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I find this format a bit difficult to parse with two - signs. What do you think about something like this:

Suggested change
typeLoc, "vector length has to be between [2 - 2^32 -1] but found ")
typeLoc, "vector length must be in the range [2, 2^32), but found ")

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

return type.getRank() == 1 &&
llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
llvm::isa<ScalarType>(type.getElementType());
// Number of elements should be between [2 - 2^32 -1],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

Comment on lines +115 to +116
const ArrayRef<spirv::Capability> elidedCapCandidates = {},
const ArrayRef<spirv::Extension> elidedExtCandidates = {}) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In mlir, we generally don't use const for arguments passed by value

Comment on lines +420 to +421
std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap);
if (ext)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap);
if (ext)
if (std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap))

Comment on lines 29 to 32
func.func @exp(%arg0 : vector<5xf32>) -> () {
// expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}}
// CHECK: spirv.GL.Exp {{%.*}} : vector<5xf32
%2 = spirv.GL.Exp %arg0 : vector<5xf32>
return
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This changes the intention behind this test. Originally it checked that the type constraints are present; could we use some other type to exercise this?

Comment on lines -21 to +31
func.func @exp(%arg0 : i32) -> () {
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
%2 = spirv.CL.exp %arg0 : i32
func.func @exp_any_vec(%arg0 : vector<5xf32>) -> () {
// CHECK: spirv.CL.exp {{%.*}} : vector<5xf32>
%2 = spirv.CL.exp %arg0 : vector<5xf32>
return
}

// -----

func.func @exp(%arg0 : vector<5xf32>) -> () {
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}}
%2 = spirv.CL.exp %arg0 : vector<5xf32>
func.func @exp(%arg0 : i32) -> () {
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
%2 = spirv.CL.exp %arg0 : i32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here and in the other test below

}
}

}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: missing newline

@@ -2,7 +2,7 @@

module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>>
spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], [SPV_KHR_uniform_group_instructions]>, #spirv.resource_limits<>>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this capability added everywhere in this test? If this is necessary to pass validation, we should land it separately from this PR.

Copy link
Member

@antiagainst antiagainst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution! I'm okay of extending the SPIR-V dialect to support VectorAnyINTEL. I see actually you've putting quite some efforts to make sure details are correct; really appreciate that! Though we still need to be quite careful here given this is touching some very fundamental expectations and validations. So please bear with me for maybe multiple rounds of reviews. :)

I left a few comments but not going through stuff extensively yet. Can we actually sequence this as multiple patches for easy reviews? E.g.,

  1. Relax the IR vector type requirements, but as @kuar said, move the validation to spirv.module op, where we know the list of final capablities. We can then go through all ops and see their types and validate.
  2. Then do the conversion side changes.

@@ -4146,7 +4146,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
// Remove the vector size restriction.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This "Remove the vector size restriction." sentense does not need to be in the comment--it's the goal of this patch; but not a proper doc of SPIRV_Vector.

@@ -4146,7 +4146,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
// Remove the vector size restriction.
// Although the vector size can be upto (2^64-1), uint64,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this actually work out for sizes > uint32 range? In SPIR-V the OpTypeVector's component count is spec'ed to be a unsigned 32-bit integer.. Did the Intel spec somehow change the definition there? Could you point me to the spec?

// Although the vector size can be upto (2^64-1), uint64,
// 2^32-1 (UNINT32_MAX>) is a more realistic number, it should serve the purpose
// for all practical cases.
// Also unsigned is used for the number elements for composite tyeps.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo: types

// `allowedRanges` list.
class VectorOfLengthRange<list<int> allowedRanges>
: Type<IsVectorOfLengthRangePred<allowedRanges>,
" of length " # !interleave(allowedRanges, "-"),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use to here instead of -? Using - reads like minus to me..

parser.emitError(
typeLoc, "vector length has to be less than or equal to 4 but found ")
typeLoc, "vector length has to be between [2 - 2^32 -1] but found ")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Capability::VectorAnyINTEL};
capabilities.push_back(caps);
}
// VectorAnyINTEL capability removes the vector size restriction and
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I follow the logic here. We are pushing duplicated capabilities here? Shouldn't we do the check vecSize > 4 to include VectorAnyINTEL first, and then additionally add Vector16 if vecSize == 8 || vecSize == 16?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants