Skip to content

[mlir] Allow all shaped types for arith ops. #99028

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,10 @@ def AnyScalableVector : ScalableVectorOf<[AnyType]>;

// Shaped types.

class ShapedTypeOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsShapedTypePred, "shaped",
"::mlir::ShapedType">;
Copy link
Member

@matthias-springer matthias-springer Jul 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One problem with this is that memrefs could now be passed because they are shaped types. E.g.:

// Should arith.addf allocate a new buffer?
%0 = arith.addf %a, %b : memref<?xf32>

I think we should allow only immutable types. Whether they are shaped types or not may not even be that important.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Being a value type is not sufficient (e.g. we shouldn't allow adding tuples or structs). So scalar or shaped value type is probably the most useful/accurate definition.


def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
"::mlir::ShapedType">;

Expand Down Expand Up @@ -842,10 +846,9 @@ class NestedTupleOf<list<Type> allowedTypes> :
// Common type constraints
//===----------------------------------------------------------------------===//
// Type constraint for types that are "like" some type or set of types T, that is
// they're either a T, a vector of Ts, or a tensor of Ts
// they're either a T or a shaped type of Ts
class TypeOrContainer<Type allowedType, string name> : TypeConstraint<Or<[
allowedType.predicate, VectorOf<[allowedType]>.predicate,
TensorOf<[allowedType]>.predicate]>,
allowedType.predicate, ShapedTypeOf<[allowedType]>.predicate]>,
name>;

// Temporary constraint to allow gradual transition to supporting 0-D vectors.
Expand Down
6 changes: 6 additions & 0 deletions mlir/test/Dialect/Arith/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ func.func @test_addi_tensor(%arg0 : tensor<8x8xi64>, %arg1 : tensor<8x8xi64>) ->
return %0 : tensor<8x8xi64>
}

// CHECK-LABEL: test_addi_unranked_tensor
func.func @test_addi_unranked_tensor(%arg0 : tensor<*xi32>, %arg1 : tensor<*xi32>) -> tensor<*xi32> {
%0 = arith.addi %arg0, %arg1 : tensor<*xi32>
return %0 : tensor<*xi32>
}

// CHECK-LABEL: test_addi_vector
func.func @test_addi_vector(%arg0 : vector<8xi64>, %arg1 : vector<8xi64>) -> vector<8xi64> {
%0 = arith.addi %arg0, %arg1 : vector<8xi64>
Expand Down
Loading