-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir] Allow all shaped types for arith ops. #99028
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-ods Author: Alexander Belyaev (pifon2a) ChangesFull diff: https://github.com/llvm/llvm-project/pull/99028.diff 2 Files Affected:
diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td
index af4f13dc09360..9414c61feb365 100644
--- a/mlir/include/mlir/IR/CommonTypeConstraints.td
+++ b/mlir/include/mlir/IR/CommonTypeConstraints.td
@@ -636,6 +636,10 @@ def AnyScalableVector : ScalableVectorOf<[AnyType]>;
// Shaped types.
+class ShapedTypeOf<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsShapedTypePred, "shaped",
+ "::mlir::ShapedType">;
+
def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
"::mlir::ShapedType">;
@@ -844,8 +848,7 @@ class NestedTupleOf<list<Type> allowedTypes> :
// Type constraint for types that are "like" some type or set of types T, that is
// they're either a T, a vector of Ts, or a tensor of Ts
class TypeOrContainer<Type allowedType, string name> : TypeConstraint<Or<[
- allowedType.predicate, VectorOf<[allowedType]>.predicate,
- TensorOf<[allowedType]>.predicate]>,
+ allowedType.predicate, ShapedTypeOf<[allowedType]>.predicate]>,
name>;
// Temporary constraint to allow gradual transition to supporting 0-D vectors.
diff --git a/mlir/test/Dialect/Arith/ops.mlir b/mlir/test/Dialect/Arith/ops.mlir
index f684e02344a51..0e786c211431a 100644
--- a/mlir/test/Dialect/Arith/ops.mlir
+++ b/mlir/test/Dialect/Arith/ops.mlir
@@ -13,6 +13,12 @@ func.func @test_addi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) ->
return %0 : tensor<8x8xi64>
}
+// CHECK-LABEL: test_addi_unranked_tensor
+func.func @test_addi_unranked_tensor(%arg0 : tensor<*xi32>, %arg1 : tensor<*xi32>) -> tensor<*xi32> {
+ %0 = arith.addi %arg0, %arg1 : tensor<*xi32>
+ return %0 : tensor<*xi32>
+}
+
// CHECK-LABEL: test_addi_vector
func.func @test_addi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
%0 = arith.addi %arg0, %arg1 : vector<8xi64>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A fundamental change like this deserves an RFC before landing. We had some recent discussions that argued in favor of the opposite direction: https://discourse.llvm.org/t/rfc-remove-arith-math-ops-on-tensors/74357.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like mostly an NFC.
I agree with @kuhar . Seems innocuous but is actually a big change. I dont have a strong opinion either way, but seems like something worth getting wider feedback on. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@@ -636,6 +636,10 @@ def AnyScalableVector : ScalableVectorOf<[AnyType]>; | |||
|
|||
// Shaped types. | |||
|
|||
class ShapedTypeOf<list<Type> allowedTypes> : | |||
ShapedContainerType<allowedTypes, IsShapedTypePred, "shaped", | |||
"::mlir::ShapedType">; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One problem with this is that memrefs could now be passed because they are shaped types. E.g.:
// Should arith.addf allocate a new buffer?
%0 = arith.addf %a, %b : memref<?xf32>
I think we should allow only immutable types. Whether they are shaped types or not may not even be that important.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Being a value type is not sufficient (e.g. we shouldn't allow adding tuples or structs). So scalar or shaped value type is probably the most useful/accurate definition.
Since arith ops support not only scalars, but also vectors and tensors, there is no reason not to support other shaped types.
We need a custom vector type that could be added, divided, etc.