-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[mlir][spirv] Add support for VectorAnyINTEL capability #68034
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-ods @llvm/pr-subscribers-mlir-core ChangesAllow vector of any lengths between [2-2^32-1]. Patch is 44.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/68034.diff 16 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index 1013cbc8ca562b7..c458a500eb367f9 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -4146,7 +4146,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">;
def SPIRV_Float32 : TypeAlias<F32, "Float32">;
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>;
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>;
-def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16],
+// Remove the vector size restriction.
+// Although the vector size can be upto (2^64-1), uint64,
+// 2^32-1 (UNINT32_MAX>) is a more realistic number, it should serve the purpose
+// for all practical cases.
+// Also unsigned is used for the number elements for composite tyeps.
+def SPIRV_Vector : VectorOfLengthRangeAndType<[2, 0xFFFFFFFF],
[SPIRV_Bool, SPIRV_Integer, SPIRV_Float]>;
// Component type check is done in the type parser for the following SPIR-V
// dialect-specific types so we use "Any" here.
@@ -4206,10 +4211,10 @@ class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
"Joint Matrix">;
class SPIRV_ScalarOrVectorOf<Type type> :
- AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>]>;
+ AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>]>;
class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
- AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>,
+ AnyTypeOf<[type, VectorOfLengthRangeAndType<[2, 0xFFFFFFFF], [type]>,
SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;
class SPIRV_MatrixOrCoopMatrixOf<Type type> :
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index 4fc14e30b8a10d0..703122547df7493 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -546,6 +546,76 @@ class ScalableVectorOfRankAndLengthAndType<list<int> allowedRanks,
ScalableVectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
+// Whether the number of elements of a vector is from the given
+// `allowedRanges` list, the list has two values, start and end
+// of the range (inclusive).
+class IsVectorOfLengthRangePred<list<int> allowedRanges>
+ : And<[IsVectorTypePred,
+ And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements()>= }] # allowedRanges[0]>,
+ CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;
+
+// Whether the number of elements of a fixed-length vector is from the given
+// `allowedRanges` list, the list has two values, start and end of the range (inclusive).
+class IsFixedVectorOfLengthRangePred<list<int> allowedRanges>
+ : And<[IsFixedVectorTypePred,
+ And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>,
+ CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;
+
+// Whether the number of elements of a scalable vector is from the given
+// `allowedRanges` list, the list has two values, start and end of the range (inclusive).
+class IsScalableVectorOfLengthRangePred<list<int> allowedRanges>
+ : And<[IsScalableVectorTypePred,
+ And<[CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() >= }] # allowedRanges[0]>,
+ CPred<[{$_self.cast<::mlir::VectorType>().getNumElements() <= }] # allowedRanges[1]>]>]>;
+
+// Any vector where the number of elements is from the given
+// `allowedRanges` list.
+class VectorOfLengthRange<list<int> allowedRanges>
+ : Type<IsVectorOfLengthRangePred<allowedRanges>,
+ " of length " # !interleave(allowedRanges, "-"),
+ "::mlir::VectorType">;
+
+// Any fixed-length vector where the number of elements is from the given
+// `allowedRanges` list.
+class FixedVectorOfLengthRange<list<int> allowedRanges>
+ : Type<IsFixedVectorOfLengthRangePred<allowedRanges>,
+ " of length " # !interleave(allowedRanges, "-"),
+ "::mlir::VectorType">;
+
+// Any scalable vector where the number of elements is from the given
+// `allowedRanges` list.
+class ScalableVectorOfLengthRange<list<int> allowedRanges>
+ : Type<IsScalableVectorOfLengthRangePred<allowedRanges>,
+ " of length " # !interleave(allowedRanges, "-"),
+ "::mlir::VectorType">;
+
+// Any vector where the number of elements is from the given
+// `allowedRanges` list and the type is from the given `allowedTypes`
+// list.
+class VectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
+ : Type<And<[VectorOf<allowedTypes>.predicate, VectorOfLengthRange<allowedRanges>.predicate]>,
+ VectorOf<allowedTypes>.summary # VectorOfLengthRange<allowedRanges>.summary,
+ "::mlir::VectorType">;
+
+// Any fixed-length vector where the number of elements is from the given
+// `allowedRanges` list and the type is from the given `allowedTypes`
+// list.
+class FixedVectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
+ : Type<
+ And<[FixedVectorOf<allowedTypes>.predicate, FixedVectorOfLengthRange<allowedRanges>.predicate]>,
+ FixedVectorOf<allowedTypes>.summary # FixedVectorOfLengthRange<allowedRanges>.summary,
+ "::mlir::VectorType">;
+
+// Any scalable vector where the number of elements is from the given
+// `allowedRanges` list and the type is from the given `allowedTypes`
+// list.
+class ScalableVectorOfLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes>
+ : Type<
+ And<[ScalableVectorOf<allowedTypes>.predicate, ScalableVectorOfLengthRange<allowedRanges>.predicate]>,
+ ScalableVectorOf<allowedTypes>.summary # ScalableVectorOfLengthRange<allowedRanges>.summary,
+ "::mlir::VectorType">;
+
+
def AnyVector : VectorOf<[AnyType]>;
// Temporary vector type clone that allows gradual transition to 0-D vectors.
def AnyVectorOfAnyRank : VectorOfAnyRankOf<[AnyType]>;
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a51d77dda78bf2f..be85d3c330a887a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -184,9 +184,12 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
return Type();
}
- if (t.getNumElements() > 4) {
+ // Number of elements should be between [2 - 2^32 -1],
+ // since getNumElements() returns an unsigned, the upper limit check is
+ // unnecessary.
+ if (t.getNumElements() < 2) {
parser.emitError(
- typeLoc, "vector length has to be less than or equal to 4 but found ")
+ typeLoc, "vector length has to be between [2 - 2^32 -1] but found ")
<< t.getNumElements();
return Type();
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 39d6603a46f965d..9d39d99b4148253 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -101,9 +101,11 @@ bool CompositeType::classof(Type type) {
}
bool CompositeType::isValid(VectorType type) {
- return type.getRank() == 1 &&
- llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) &&
- llvm::isa<ScalarType>(type.getElementType());
+ // Number of elements should be between [2 - 2^32 -1],
+ // since getNumElements() returns an unsigned, the upper limit check is
+ // unnecessary.
+ return type.getRank() == 1 && llvm::isa<ScalarType>(type.getElementType()) &&
+ type.getNumElements() >= 2;
}
Type CompositeType::getElementType(unsigned index) const {
@@ -171,9 +173,17 @@ void CompositeType::getCapabilities(
.Case<VectorType>([&](VectorType type) {
auto vecSize = getNumElements();
if (vecSize == 8 || vecSize == 16) {
- static const Capability caps[] = {Capability::Vector16};
- ArrayRef<Capability> ref(caps, std::size(caps));
- capabilities.push_back(ref);
+ static constexpr Capability caps[] = {Capability::Vector16,
+ Capability::VectorAnyINTEL};
+ capabilities.push_back(caps);
+ }
+ // VectorAnyINTEL capability removes the vector size restriction and
+ // allows the vector size to be up to (2^32-1).
+ // Vector16 capability allows the vector size to be 8 and 16
+ SmallVector<unsigned, 5> allowedVecRange = {2, 3, 4, 8, 16};
+ if (vecSize >= 2 && !llvm::is_contained(allowedVecRange, vecSize)) {
+ static constexpr Capability caps[] = {Capability::VectorAnyINTEL};
+ capabilities.push_back(caps);
}
return llvm::cast<ScalarType>(type.getElementType())
.getCapabilities(capabilities, storage);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c75d217663a9e09..25e6a080642e681 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -43,9 +43,13 @@ using namespace mlir;
template <typename LabelT>
static LogicalResult checkExtensionRequirements(
LabelT label, const spirv::TargetEnv &targetEnv,
- const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
+ const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
+ const ArrayRef<spirv::Extension> elidedCandidates = {}) {
for (const auto &ors : candidates) {
- if (targetEnv.allows(ors))
+ if (targetEnv.allows(ors) ||
+ llvm::any_of(elidedCandidates, [&ors](spirv::Extension elidedExt) {
+ return llvm::is_contained(ors, elidedExt);
+ }))
continue;
LLVM_DEBUG({
@@ -71,9 +75,13 @@ static LogicalResult checkExtensionRequirements(
template <typename LabelT>
static LogicalResult checkCapabilityRequirements(
LabelT label, const spirv::TargetEnv &targetEnv,
- const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
+ const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
+ const ArrayRef<spirv::Capability> elidedCandidates = {}) {
for (const auto &ors : candidates) {
- if (targetEnv.allows(ors))
+ if (targetEnv.allows(ors) ||
+ llvm::any_of(elidedCandidates, [&ors](spirv::Capability elidedCap) {
+ return llvm::is_contained(ors, elidedCap);
+ }))
continue;
LLVM_DEBUG({
@@ -90,8 +98,55 @@ static LogicalResult checkCapabilityRequirements(
return success();
}
-/// Returns true if the given `storageClass` needs explicit layout when used in
-/// Shader environments.
+/// Check capabilities and extensions requirements
+/// Checks that `capCandidates`, `extCandidates`, and capability
+/// (`capCandidates`) infered extension requirements are possible to be
+/// satisfied with the given `targetEnv`.
+/// It also provides a way to relax requirements for certain capabilities and
+/// extensions (e.g., `elidedCapCandidates`, `elidedExtCandidates`), this is to
+/// allow passes to relax certain requirements based on an option (e.g.,
+/// relaxing bitwidth requirement, see `convertScalarType()`,
+/// `ConvertVectorType()`).
+template <typename LabelT>
+static LogicalResult checkCapabilityAndExtensionRequirements(
+ LabelT label, const spirv::TargetEnv &targetEnv,
+ const spirv::SPIRVType::CapabilityArrayRefVector &capCandidates,
+ const spirv::SPIRVType::ExtensionArrayRefVector &extCandidates,
+ const ArrayRef<spirv::Capability> elidedCapCandidates = {},
+ const ArrayRef<spirv::Extension> elidedExtCandidates = {}) {
+ SmallVector<ArrayRef<spirv::Extension>, 8> updatedExtCandidates;
+ llvm::append_range(updatedExtCandidates, extCandidates);
+
+ if (failed(checkCapabilityRequirements(label, targetEnv, capCandidates,
+ elidedCapCandidates)))
+ return failure();
+ // Add capablity infered extensions to the list of extension requirement list,
+ // only considers the capabilities that already available in the `targetEnv`.
+
+ // WARNING: Some capabilities are part of both the core SPIR-V
+ // specification and an extension (e.g., 'Groups' capability is part of both
+ // core specification and SPV_AMD_shader_ballot extension, hence we should
+ // relax the capability inferred extension for these cases).
+ static const spirv::Capability multiModalCaps[] = {spirv::Capability::Groups};
+ ArrayRef<spirv::Capability> multiModalCapsArrayRef(multiModalCaps,
+ std::size(multiModalCaps));
+
+ for (auto cap : targetEnv.getAttr().getCapabilities()) {
+ if (llvm::any_of(multiModalCapsArrayRef,
+ [&cap](spirv::Capability mMCap) { return cap == mMCap; }))
+ continue;
+ std::optional<ArrayRef<spirv::Extension>> ext = getExtensions(cap);
+ if (ext)
+ updatedExtCandidates.push_back(*ext);
+ }
+ if (failed(checkExtensionRequirements(label, targetEnv, updatedExtCandidates,
+ elidedExtCandidates)))
+ return failure();
+ return success();
+}
+
+/// Returns true if the given `storageClass` needs explicit layout when used
+/// in Shader environments.
static bool needsExplicitLayout(spirv::StorageClass storageClass) {
switch (storageClass) {
case spirv::StorageClass::PhysicalStorageBuffer:
@@ -230,8 +285,8 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
type.getCapabilities(capabilities, storageClass);
// If all requirements are met, then we can accept this type as-is.
- if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
- succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
+ if (succeeded(checkCapabilityAndExtensionRequirements(
+ type, targetEnv, capabilities, extensions)))
return type;
// Otherwise we need to adjust the type, which really means adjusting the
@@ -342,15 +397,35 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
cast<spirv::CompositeType>(type).getExtensions(extensions, storageClass);
cast<spirv::CompositeType>(type).getCapabilities(capabilities, storageClass);
+ // If the bit-width related capabilities and extensions are not met
+ // for lower bit-width (<32-bit), convert it to 32-bit
+ auto elementType =
+ convertScalarType(targetEnv, options, scalarType, storageClass);
+ if (!elementType)
+ return nullptr;
+ type = VectorType::get(type.getShape(), elementType);
+
+ SmallVector<spirv::Capability, 4> elidedCaps;
+ SmallVector<spirv::Extension, 4> elidedExts;
+
+ // Relax the bitwidth requirements for capabilities and extensions
+ if (options.emulateLT32BitScalarTypes) {
+ elidedCaps.push_back(spirv::Capability::Int8);
+ elidedCaps.push_back(spirv::Capability::Int16);
+ elidedCaps.push_back(spirv::Capability::Float16);
+ }
+ // For capabilities whose requirements were relaxed, relax requirements for
+ // the extensions that were infered by those capabilities (e.g., elidedCaps)
+ for (spirv::Capability cap : elidedCaps) {
+ std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap);
+ if (ext)
+ llvm::append_range(elidedExts, *ext);
+ }
// If all requirements are met, then we can accept this type as-is.
- if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
- succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
+ if (succeeded(checkCapabilityAndExtensionRequirements(
+ type, targetEnv, capabilities, extensions, elidedCaps, elidedExts)))
return type;
- auto elementType =
- convertScalarType(targetEnv, options, scalarType, storageClass);
- if (elementType)
- return VectorType::get(type.getShape(), elementType);
return nullptr;
}
@@ -656,8 +731,9 @@ std::optional<Value> castToSourceType(const spirv::TargetEnv &targetEnv,
SmallVector<ArrayRef<spirv::Capability>, 2> caps;
scalarType.getExtensions(exts);
scalarType.getCapabilities(caps);
- if (failed(checkCapabilityRequirements(type, targetEnv, caps)) ||
- failed(checkExtensionRequirements(type, targetEnv, exts))) {
+
+ if (failed(checkCapabilityAndExtensionRequirements(type, targetEnv, caps,
+ exts))) {
auto castOp = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
return castOp.getResult(0);
}
@@ -1150,16 +1226,18 @@ bool SPIRVConversionTarget::isLegalOp(Operation *op) {
SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
for (Type valueType : valueTypes) {
- typeExtensions.clear();
- cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
- if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
- typeExtensions)))
- return false;
-
typeCapabilities.clear();
cast<spirv::SPIRVType>(valueType).getCapabilities(typeCapabilities);
- if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
- typeCapabilities)))
+ typeExtensions.clear();
+ cast<spirv::SPIRVType>(valueType).getExtensions(typeExtensions);
+ // Checking for capability and extension requirements along with capability
+ // infered extensions.
+ // If a capability is present, the extension that
+ // supports it should also be present, this reduces the burden of adding
+ // extension requirement that may or maynot be added in
+ // CompositeType::getExtensions().
+ if (failed(checkCapabilityAndExtensionRequirements(
+ op->getName(), this->targetEnv, typeCapabilities, typeExtensions)))
return false;
}
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 0d92a8e676d8570..d61ace8d6876b87 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -11,9 +11,9 @@ module attributes {
#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
} {
-func.func @unsupported_5elem_vector(%arg0: vector<5xi32>) {
+func.func @unsupported_5elem_vector(%arg0: vector<5xi32>, %arg1: vector<5xi32>) {
// expected-error@+1 {{failed to legalize operation 'arith.subi'}}
- %1 = arith.subi %arg0, %arg0: vector<5xi32>
+ %1 = arith.subi %arg0, %arg1: vector<5xi32>
return
}
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 0221e4815a9397d..6ceeade486efd68 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1407,3 +1407,43 @@ func.func @float_scalar(%arg0: f16) {
}
} // end module
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// VectorAnyINTEL support
+//===----------------------------------------------------------------------===//
+
+// Check that with VectorAnyINTEL, VectorComputeINTEL capability,
+// and SPV_INTEL_vector_compute extension, any sized (2-2^32 -1) vector is allowed.
+module attributes {
+ spirv.target_env = #spirv.target_env<
+ #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Kernel, VectorAnyINTEL], [SPV_INTEL_vector_compute]>, #spirv.resource_limits<>>
+} {
+
+// CHECK-LABEL: @any_vector
+func.func @any_vector(%arg0: vector<16xi32>, %arg1: vector<16xi32>) {
+ // CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<16xi32>
+ %0 = arith.subi %arg0, %arg1: vector<16xi32>
+ return
+}
+
+// CHECK-LABEL: @max_vector
+func.func @max_vector(%arg0: vector<4294967295xi32>, %arg1: vector<4294967295xi32>) {
+ // CHECK: spirv.ISub %{{.+}}, %{{.+}}: vector<4294967295xi32>
+ %0 = arith.subi %arg0, %arg1: vector<4294967295xi32>
+ return
+}
+
+
+// Check float vector types of any size.
+// CHECK-LABEL: @float_vector58
+func.func @float_vector58(%arg0: vector<5xf16>, %arg1: vector<8xf64>) {
+ // CHECK: spirv.FAdd %{{.*}}, %{{.*}}: vector<5xf16>
+ %0 = arith.addf %arg0, %arg0: vector<5xf16>
+ // CHECK: spirv.FMul %{{.*}}, %{{.*}}: vector<8xf64>
+ %1 = arith.mulf %a...
[truncated]
|
47e14c0
to
f1b971b
Compare
f1b971b
to
22cd31a
Compare
… a range of values Add types and predicates for Vector, Fixed Vector, and Scalable Vector whose number of elements is from a given `allowedRanges` list. The list has two values, start and end of the range (inclusive).
Allow a way to relax requirements for certain capabilities and extensions (e.g., `elidedCandidates`). Also add a combined check for capabilities and extensions in `checkCapabilityAndExtensionRequirements`. This function checks capabilities, extensions, and capability infered extension requirements.
…ility requirements Replace the seperate extension and capability checking with combined check `checkCapabilityAndExtensionRequirements()`. This makes the code flow simpler. Also adds the extra check for capability inferred extension check. Need for capability inferred extension check: If a capability is a requirement, the respective extension that implements it should also become an extension requirement, there were no support for that check, as a result, the extension requirement had to be added separately. This separate requirement addition causes problem when a feature is enabled by multiple capability, and one of the capability is part of an extension. E.g., vector size of 16 can be enabled by both "Vector16" and "vectorAnyINTEL" capability, however, only "vectorAnyINTEL" has an extension requirement ("SPV_INTEL_vector_compute"). Since the process of adding capability and extension requirement are independent, there is no way, to handle cases like this. Therefore, for cases like this, enable adding capability requirement initially, then do the check for capability inferred extension.
Allow vector of any lengths between [2-2^32-1]. VectorAnyINTEL capability (part of "SPV_INTEL_vector_compute" extension) relaxes the length constraint on SPIR-V vector sizes from 2,3, and 4.
22cd31a
to
1041658
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that relaxing the vector length requirements necessitates a new module-level check that the types are compatible with the declared capabilities. Do you have any opinions on this, @antiagainst?
// Any scalable vector where the minimum number of elements is from the given | ||
// `allowedRanges` list and the type is from the given `allowedTypes` | ||
// list. | ||
class ScalableVectorOfMinLengthRangeAndType<list<int> allowedRanges, list<Type> allowedTypes> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we exclude any predicated not used by the SPIR-V dialect?
parser.emitError( | ||
typeLoc, "vector length has to be less than or equal to 4 but found ") | ||
typeLoc, "vector length has to be between [2 - 2^32 -1] but found ") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I find this format a bit difficult to parse with two -
signs. What do you think about something like this:
typeLoc, "vector length has to be between [2 - 2^32 -1] but found ") | |
typeLoc, "vector length must be in the range [2, 2^32), but found ") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
return type.getRank() == 1 && | ||
llvm::is_contained({2, 3, 4, 8, 16}, type.getNumElements()) && | ||
llvm::isa<ScalarType>(type.getElementType()); | ||
// Number of elements should be between [2 - 2^32 -1], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here.
const ArrayRef<spirv::Capability> elidedCapCandidates = {}, | ||
const ArrayRef<spirv::Extension> elidedExtCandidates = {}) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In mlir, we generally don't use const
for arguments passed by value
std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap); | ||
if (ext) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap); | |
if (ext) | |
if (std::optional<ArrayRef<spirv::Extension>> ext = spirv::getExtensions(cap)) |
func.func @exp(%arg0 : vector<5xf32>) -> () { | ||
// expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}} | ||
// CHECK: spirv.GL.Exp {{%.*}} : vector<5xf32 | ||
%2 = spirv.GL.Exp %arg0 : vector<5xf32> | ||
return |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This changes the intention behind this test. Originally it checked that the type constraints are present; could we use some other type to exercise this?
func.func @exp(%arg0 : i32) -> () { | ||
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} | ||
%2 = spirv.CL.exp %arg0 : i32 | ||
func.func @exp_any_vec(%arg0 : vector<5xf32>) -> () { | ||
// CHECK: spirv.CL.exp {{%.*}} : vector<5xf32> | ||
%2 = spirv.CL.exp %arg0 : vector<5xf32> | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
func.func @exp(%arg0 : vector<5xf32>) -> () { | ||
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4}} | ||
%2 = spirv.CL.exp %arg0 : vector<5xf32> | ||
func.func @exp(%arg0 : i32) -> () { | ||
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}} | ||
%2 = spirv.CL.exp %arg0 : i32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here and in the other test below
} | ||
} | ||
|
||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: missing newline
@@ -2,7 +2,7 @@ | |||
|
|||
module attributes { | |||
gpu.container_module, | |||
spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], []>, #spirv.resource_limits<>> | |||
spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupNonUniformArithmetic, GroupUniformArithmeticKHR], [SPV_KHR_uniform_group_instructions]>, #spirv.resource_limits<>> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this capability added everywhere in this test? If this is necessary to pass validation, we should land it separately from this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution! I'm okay of extending the SPIR-V dialect to support VectorAnyINTEL. I see actually you've putting quite some efforts to make sure details are correct; really appreciate that! Though we still need to be quite careful here given this is touching some very fundamental expectations and validations. So please bear with me for maybe multiple rounds of reviews. :)
I left a few comments but not going through stuff extensively yet. Can we actually sequence this as multiple patches for easy reviews? E.g.,
- Relax the IR vector type requirements, but as @kuar said, move the validation to
spirv.module
op, where we know the list of final capablities. We can then go through all ops and see their types and validate. - Then do the conversion side changes.
@@ -4146,7 +4146,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">; | |||
def SPIRV_Float32 : TypeAlias<F32, "Float32">; | |||
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>; | |||
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>; | |||
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16], | |||
// Remove the vector size restriction. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This "Remove the vector size restriction." sentense does not need to be in the comment--it's the goal of this patch; but not a proper doc of SPIRV_Vector
.
@@ -4146,7 +4146,12 @@ def SPIRV_Int32 : TypeAlias<I32, "Int32">; | |||
def SPIRV_Float32 : TypeAlias<F32, "Float32">; | |||
def SPIRV_Float : FloatOfWidths<[16, 32, 64]>; | |||
def SPIRV_Float16or32 : FloatOfWidths<[16, 32]>; | |||
def SPIRV_Vector : VectorOfLengthAndType<[2, 3, 4, 8, 16], | |||
// Remove the vector size restriction. | |||
// Although the vector size can be upto (2^64-1), uint64, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How does this actually work out for sizes > uint32 range? In SPIR-V the OpTypeVector
's component count is spec'ed to be a unsigned 32-bit integer.. Did the Intel spec somehow change the definition there? Could you point me to the spec?
// Although the vector size can be upto (2^64-1), uint64, | ||
// 2^32-1 (UNINT32_MAX>) is a more realistic number, it should serve the purpose | ||
// for all practical cases. | ||
// Also unsigned is used for the number elements for composite tyeps. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
typo: types
// `allowedRanges` list. | ||
class VectorOfLengthRange<list<int> allowedRanges> | ||
: Type<IsVectorOfLengthRangePred<allowedRanges>, | ||
" of length " # !interleave(allowedRanges, "-"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use to
here instead of -
? Using -
reads like minus to me..
parser.emitError( | ||
typeLoc, "vector length has to be less than or equal to 4 but found ") | ||
typeLoc, "vector length has to be between [2 - 2^32 -1] but found ") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
Capability::VectorAnyINTEL}; | ||
capabilities.push_back(caps); | ||
} | ||
// VectorAnyINTEL capability removes the vector size restriction and |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure I follow the logic here. We are pushing duplicated capabilities here? Shouldn't we do the check vecSize > 4
to include VectorAnyINTEL
first, and then additionally add Vector16
if vecSize == 8 || vecSize == 16
?
Allow vector of any lengths between [2-2^32-1].
VectorAnyINTEL capability (part of "SPV_INTEL_vector_compute" extension) relaxes the length constraint on SPIR-V vector sizes from 2,3, and 4.