Skip to content

[mlir][spirv] Make CooperativeMatrixType a ShapedType #142784

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

IgWod-IMG
Copy link
Contributor

@IgWod-IMG IgWod-IMG commented Jun 4, 2025

This is to enable CooperativeMatrixType to be used with DenseElementsAttr, so that a spirv.Constant can be easily built from OpConstantComposite. For example:

%cst = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<1x1xf32, Subgroup, MatrixAcc>

Additional constraints are added to arithmetic operations, as SameOperandsAndResultType can no longer fully verify CoopMatrices. This is because for shaped types the verifier only checks element type and shapes, whereas for any other arbitrary type it looks for an exact match.

This patch does not enable the actual deserialization. This is done in #142786.

This is to enable `CooperativeMatrixType` to be used with
`DenseElementsAttr`, so that a `spirv.Constant` can be easily
built from `OpConstantComposite`. For example:

```mlir
%cst = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<1x1xf32, Subgroup, MatrixAcc>
```

Additional constraints are added to arithmetic operations, as
`SameOperandsAndResultType` can no longer fully verify CoopMatrices.
This is because for shaped types the verifier only checks
element type and shapes, whereas for any other arbitrary type it
looks for an exact match.

This patch does not enable the actual deserialization.
@llvmbot
Copy link
Member

llvmbot commented Jun 4, 2025

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Igor Wodiany (IgWod-IMG)

Changes

This is to enable CooperativeMatrixType to be used with DenseElementsAttr, so that a spirv.Constant can be easily built from OpConstantComposite. For example:

%cst = spirv.Constant dense&lt;0.000000e+00&gt; : !spirv.coopmatrix&lt;1x1xf32, Subgroup, MatrixAcc&gt;

Additional constraints are added to arithmetic operations, as SameOperandsAndResultType can no longer fully verify CoopMatrices. This is because for shaped types the verifier only checks element type and shapes, whereas for any other arbitrary type it looks for an exact match.

This patch does not enable the actual deserialization. This is done in #142786.


Full diff: https://github.com/llvm/llvm-project/pull/142784.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td (+12-2)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h (+19-1)
  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp (+14-8)
  • (modified) mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
index 22d5afcd77381..48f525e048e60 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -18,12 +18,21 @@ include "mlir/Dialect/SPIRV/IR/SPIRVBase.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
+class SPIRV_SameCoopMatrix<string lhs, string rhs> : PredOpTrait<
+  "cooperative matrix types match",
+  CPred<"(::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # lhs # ".getType()) "
+        "&& ::llvm::isa<::mlir::spirv::CooperativeMatrixType>($" # rhs # ".getType()))"
+        "? $" # lhs # ".getType() == $" # rhs # ".getType() : true">
+>;
+
 class SPIRV_ArithmeticBinaryOp<string mnemonic, Type type,
                                list<Trait> traits = []> :
       // Operands type same as result type.
       SPIRV_BinaryOp<mnemonic, type, type,
                    !listconcat(traits,
-                               [Pure, SameOperandsAndResultType])> {
+                               [Pure, SameOperandsAndResultType,
+                                SPIRV_SameCoopMatrix<"operand1", "operand2">,
+                                SPIRV_SameCoopMatrix<"operand2", "result">])> {
   // In addition to normal types arithmetic instructions can support cooperative
   // matrix.
   let arguments = (ins
@@ -42,7 +51,8 @@ class SPIRV_ArithmeticUnaryOp<string mnemonic, Type type,
       // Operand type same as result type.
       SPIRV_UnaryOp<mnemonic, type, type,
                    !listconcat(traits,
-                               [Pure, SameOperandsAndResultType])> {
+                               [Pure, SameOperandsAndResultType,
+                                SPIRV_SameCoopMatrix<"operand", "result">])> {
   // In addition to normal types arithmetic instructions can support cooperative
   // matrix.
   let arguments = (ins
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 2e29e9afaabf4..a7b6569245dd5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -394,7 +394,8 @@ hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo);
 // SPIR-V KHR cooperative matrix type
 class CooperativeMatrixType
     : public Type::TypeBase<CooperativeMatrixType, CompositeType,
-                            detail::CooperativeMatrixTypeStorage> {
+                            detail::CooperativeMatrixTypeStorage,
+                            ShapedType::Trait> {
 public:
   using Base::Base;
 
@@ -418,6 +419,23 @@ class CooperativeMatrixType
                      std::optional<StorageClass> storage = std::nullopt);
   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
                        std::optional<StorageClass> storage = std::nullopt);
+
+  operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+
+  ArrayRef<int64_t> getShape() const;
+
+  bool hasRank() const { return true; }
+
+  CooperativeMatrixType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+                                  Type elementType) const {
+    if (shape == std::nullopt)
+      return get(elementType, getRows(), getColumns(), getScope(), getUse());
+    else {
+      assert(shape.value().size() == 2);
+      return get(elementType, shape.value()[0], shape.value()[1], getScope(),
+                 getUse());
+    }
+  }
 };
 
 // SPIR-V matrix type
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 337df3a5a65f0..de2034680cd5f 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -195,7 +195,7 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
 
 struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
   using KeyTy =
-      std::tuple<Type, uint32_t, uint32_t, Scope, CooperativeMatrixUseKHR>;
+      std::tuple<Type, int64_t, int64_t, Scope, CooperativeMatrixUseKHR>;
 
   static CooperativeMatrixTypeStorage *
   construct(TypeStorageAllocator &allocator, const KeyTy &key) {
@@ -204,17 +204,17 @@ struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage {
   }
 
   bool operator==(const KeyTy &key) const {
-    return key == KeyTy(elementType, rows, columns, scope, use);
+    return key == KeyTy(elementType, shape[0], shape[1], scope, use);
   }
 
   CooperativeMatrixTypeStorage(const KeyTy &key)
-      : elementType(std::get<0>(key)), rows(std::get<1>(key)),
-        columns(std::get<2>(key)), scope(std::get<3>(key)),
+      : elementType(std::get<0>(key)),
+        shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)),
         use(std::get<4>(key)) {}
 
   Type elementType;
-  uint32_t rows;
-  uint32_t columns;
+  // [#rows, #columns]
+  SmallVector<int64_t, 2> shape;
   Scope scope;
   CooperativeMatrixUseKHR use;
 };
@@ -231,10 +231,16 @@ Type CooperativeMatrixType::getElementType() const {
   return getImpl()->elementType;
 }
 
-uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; }
+uint32_t CooperativeMatrixType::getRows() const {
+  return static_cast<uint32_t>(getImpl()->shape[0]);
+}
 
 uint32_t CooperativeMatrixType::getColumns() const {
-  return getImpl()->columns;
+  return static_cast<uint32_t>(getImpl()->shape[1]);
+}
+
+ArrayRef<int64_t> CooperativeMatrixType::getShape() const {
+  return getImpl()->shape;
 }
 
 Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; }
diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
index d3e1dbc229ef9..4ae8b70bf43ca 100644
--- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir
@@ -524,7 +524,7 @@ spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" {
 
 spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
                  %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" {
-  // expected-error @+1 {{op requires the same type for all operands and results}}
+  // expected-error @+1 {{op failed to verify that cooperative matrix types match}}
   %q = "spirv.IAdd"(%a, %b) :
     (!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>)
     -> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>
@@ -535,7 +535,7 @@ spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>,
 
 spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>,
                  %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" {
-  // expected-error @+1 {{op requires the same type for all operands and results}}
+  // expected-error @+1 {{op failed to verify that cooperative matrix types match}}
   %q = "spirv.FAdd"(%a, %b) :
     (!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>)
     -> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants