Skip to content

[mlir] Allow all shaped types for arith ops. #99028

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

pifon2a
Copy link
Contributor

@pifon2a pifon2a commented Jul 16, 2024

Since arith ops support not only scalars, but also vectors and tensors, there is no reason not to support other shaped types.

We need a custom vector type that could be added, divided, etc.

@llvmbot
Copy link
Member

llvmbot commented Jul 16, 2024

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-arith

@llvm/pr-subscribers-mlir-ods

Author: Alexander Belyaev (pifon2a)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/99028.diff

2 Files Affected:

  • (modified) mlir/include/mlir/IR/CommonTypeConstraints.td (+5-2)
  • (modified) mlir/test/Dialect/Arith/ops.mlir (+6)
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index af4f13dc09360..9414c61feb365 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -636,6 +636,10 @@ def AnyScalableVector : ScalableVectorOf<[AnyType]>;
 
 // Shaped types.
 
+class ShapedTypeOf<list<Type> allowedTypes> :
+  ShapedContainerType<allowedTypes, IsShapedTypePred, "shaped",
+                      "::mlir::ShapedType">;
+
 def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
                                    "::mlir::ShapedType">;
 
@@ -844,8 +848,7 @@ class NestedTupleOf<list<Type> allowedTypes> :
 // Type constraint for types that are "like" some type or set of types T, that is
 // they're either a T, a vector of Ts, or a tensor of Ts
 class TypeOrContainer<Type allowedType, string name> : TypeConstraint<Or<[
-  allowedType.predicate, VectorOf<[allowedType]>.predicate,
-  TensorOf<[allowedType]>.predicate]>,
+  allowedType.predicate, ShapedTypeOf<[allowedType]>.predicate]>,
   name>;
 
 // Temporary constraint to allow gradual transition to supporting 0-D vectors.
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index f684e02344a51..0e786c211431a 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -13,6 +13,12 @@ func.func @test_addi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) ->
   return %0 : tensor<8x8xi64>
 }
 
+// CHECK-LABEL: test_addi_unranked_tensor
+func.func @test_addi_unranked_tensor(%arg0 : tensor<*xi32>, %arg1 : tensor<*xi32>) -> tensor<*xi32> {
+  %0 = arith.addi %arg0, %arg1 : tensor<*xi32>
+  return %0 : tensor<*xi32>
+}
+
 // CHECK-LABEL: test_addi_vector
 func.func @test_addi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
   %0 = arith.addi %arg0, %arg1 : vector<8xi64>

@pifon2a pifon2a requested a review from jreiffers July 16, 2024 12:54
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A fundamental change like this deserves an RFC before landing. We had some recent discussions that argued in favor of the opposite direction: https://discourse.llvm.org/t/rfc-remove-arith-math-ops-on-tensors/74357.

Copy link
Member

@jreiffers jreiffers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like mostly an NFC.

@MaheshRavishankar
Copy link
Contributor

I agree with @kuhar . Seems innocuous but is actually a big change. I dont have a strong opinion either way, but seems like something worth getting wider feedback on.

Copy link
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -636,6 +636,10 @@ def AnyScalableVector : ScalableVectorOf<[AnyType]>;

// Shaped types.

class ShapedTypeOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsShapedTypePred, "shaped",
"::mlir::ShapedType">;
Copy link
Member

@matthias-springer matthias-springer Jul 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One problem with this is that memrefs could now be passed because they are shaped types. E.g.:

// Should arith.addf allocate a new buffer?
%0 = arith.addf %a, %b : memref<?xf32>

I think we should allow only immutable types. Whether they are shaped types or not may not even be that important.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Being a value type is not sufficient (e.g. we shouldn't allow adding tuples or structs). So scalar or shaped value type is probably the most useful/accurate definition.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

8 participants