-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][spirv] Make CooperativeMatrixType
a ShapedType
#142784
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This is to enable `CooperativeMatrixType` to be used with `DenseElementsAttr`, so that a `spirv.Constant` can be easily built from `OpConstantComposite`. For example: ```mlir %cst = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<1x1xf32, Subgroup, MatrixAcc> ``` Additional constraints are added to arithmetic operations, as `SameOperandsAndResultType` can no longer fully verify CoopMatrices. This is because for shaped types the verifier only checks element type and shapes, whereas for any other arbitrary type it looks for an exact match. This patch does not enable the actual deserialization.
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Igor Wodiany (IgWod-IMG) ChangesThis is to enable %cst = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<1x1xf32, Subgroup, MatrixAcc> Additional constraints are added to arithmetic operations, as This patch does not enable the actual deserialization. This is done in #142786. Full diff: https://github.com/llvm/llvm-project/pull/142784.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 22d5afcd77381..48f525e048e60 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -18,12 +18,21 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+class SPIRV_SameCoopMatrix<string lhs, string rhs> : PredOpTrait<
+ "cooperative matrix types match",
+ CPred<"(::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # lhs # ".getType()) "
+ "&& ::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # rhs # ".getType()))"
+ "? $" # lhs # ".getType() == $" # rhs # ".getType() : true">
+>;
+
class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
list<Trait> traits = []> :
// Operands type same as result type.
SPIRV_BinaryOp<mnemonic, type, type,
!listconcat(traits,
- [Pure, SameOperandsAndResultType])> {
+ [Pure, SameOperandsAndResultType,
+ SPIRV_SameCoopMatrix<"operand1", "operand2">,
+ SPIRV_SameCoopMatrix<"operand2", "result">])> {
// In addition to normal types arithmetic instructions can support cooperative
// matrix.
let arguments = (ins
@@ -42,7 +51,8 @@ class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
// Operand type same as result type.
SPIRV_UnaryOp<mnemonic, type, type,
!listconcat(traits,
- [Pure, SameOperandsAndResultType])> {
+ [Pure, SameOperandsAndResultType,
+ SPIRV_SameCoopMatrix<"operand", "result">])> {
// In addition to normal types arithmetic instructions can support cooperative
// matrix.
let arguments = (ins
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 2e29e9afaabf4..a7b6569245dd5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -394,7 +394,8 @@ hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
// SPIR-V KHR cooperative matrix type
class CooperativeMatrixType
: public Type::TypeBase<CooperativeMatrixType, CompositeType,
- detail::CooperativeMatrixTypeStorage> {
+ detail::CooperativeMatrixTypeStorage,
+ ShapedType::Trait> {
public:
using Base::Base;
@@ -418,6 +419,23 @@ class CooperativeMatrixType
std::optional<StorageClass> storage = std::nullopt);
void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
std::optional<StorageClass> storage = std::nullopt);
+
+ operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+
+ ArrayRef<int64_t> getShape() const;
+
+ bool hasRank() const { return true; }
+
+ CooperativeMatrixType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const {
+ if (shape == std::nullopt)
+ return get(elementType, getRows(), getColumns(), getScope(), getUse());
+ else {
+ assert(shape.value().size() == 2);
+ return get(elementType, shape.value()[0], shape.value()[1], getScope(),
+ getUse());
+ }
+ }
};
// SPIR-V matrix type
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 337df3a5a65f0..de2034680cd5f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -195,7 +195,7 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
using KeyTy =
- std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
+ std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
static CooperativeMatrixTypeStorage *
construct(TypeStorageAllocator &allocator, const KeyTy &key) {
@@ -204,17 +204,17 @@ struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
}
bool operator==(const KeyTy &key) const {
- return key == KeyTy(elementType, rows, columns, scope, use);
+ return key == KeyTy(elementType, shape[0], shape[1], scope, use);
}
CooperativeMatrixTypeStorage(const KeyTy &key)
- : elementType(std::get<0>(key)), rows(std::get<1>(key)),
- columns(std::get<2>(key)), scope(std::get<3>(key)),
+ : elementType(std::get<0>(key)),
+ shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
use(std::get<4>(key)) {}
Type elementType;
- uint32_t rows;
- uint32_t columns;
+ // [#rows, #columns]
+ SmallVector<int64_t, 2> shape;
Scope scope;
CooperativeMatrixUseKHR use;
};
@@ -231,10 +231,16 @@ Type CooperativeMatrixType::getElementType() const {
return getImpl()->elementType;
}
-uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; }
+uint32_t CooperativeMatrixType::getRows() const {
+ return static_cast<uint32_t>(getImpl()->shape[0]);
+}
uint32_t CooperativeMatrixType::getColumns() const {
- return getImpl()->columns;
+ return static_cast<uint32_t>(getImpl()->shape[1]);
+}
+
+ArrayRef<int64_t> CooperativeMatrixType::getShape() const {
+ return getImpl()->shape;
}
Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index d3e1dbc229ef9..4ae8b70bf43ca 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -524,7 +524,7 @@ spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" {
spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
%b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" {
- // expected-error @+1 {{op requires the same type for all operands and results}}
+ // expected-error @+1 {{op failed to verify that cooperative matrix types match}}
%q = "spirv.IAdd"(%a, %b) :
(!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>)
-> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
@@ -535,7 +535,7 @@ spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>,
%b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" {
- // expected-error @+1 {{op requires the same type for all operands and results}}
+ // expected-error @+1 {{op failed to verify that cooperative matrix types match}}
%q = "spirv.FAdd"(%a, %b) :
(!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>)
-> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>
|
This is to enable
CooperativeMatrixType
to be used withDenseElementsAttr
, so that aspirv.Constant
can be easily built fromOpConstantComposite
. For example:Additional constraints are added to arithmetic operations, as
SameOperandsAndResultType
can no longer fully verify CoopMatrices. This is because for shaped types the verifier only checks element type and shapes, whereas for any other arbitrary type it looks for an exact match.This patch does not enable the actual deserialization. This is done in #142786.