Skip to content

[mlir] Add encoding attribute to VectorType. #99029

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,13 @@ class VectorType::Builder {
/// Build from another VectorType.
explicit Builder(VectorType other)
: elementType(other.getElementType()), shape(other.getShape()),
scalableDims(other.getScalableDims()) {}
scalableDims(other.getScalableDims()), encoding(other.getEncoding()) {}

/// Build from scratch.
Builder(ArrayRef<int64_t> shape, Type elementType,
ArrayRef<bool> scalableDims = {})
: elementType(elementType), shape(shape), scalableDims(scalableDims) {}
ArrayRef<bool> scalableDims = {}, Attribute encoding = nullptr)
: elementType(elementType), shape(shape), scalableDims(scalableDims),
encoding(encoding) {}

Builder &setShape(ArrayRef<int64_t> newShape,
ArrayRef<bool> newIsScalableDim = {}) {
Expand Down Expand Up @@ -342,14 +343,20 @@ class VectorType::Builder {
return *this;
}

Builder &setEncoding(Attribute newEncoding) {
encoding = newEncoding;
return *this;
}

operator VectorType() {
return VectorType::get(shape, elementType, scalableDims);
return VectorType::get(shape, elementType, scalableDims, encoding);
}

private:
Type elementType;
CopyOnWriteArrayRef<int64_t> shape;
CopyOnWriteArrayRef<bool> scalableDims;
Attribute encoding;
};

/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
Expand Down
17 changes: 14 additions & 3 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
vector-dim-list := (static-dim-list `x`)?
static-dim-list ::= static-dim (`x` static-dim)*
static-dim ::= (decimal-literal | `[` decimal-literal `]`)
encoding ::= attribute-value
```

The vector type represents a SIMD style vector used by target-specific
Expand All @@ -1072,6 +1073,10 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
Vector shapes must be positive decimal integers. 0D vectors are allowed by
omitting the dimension: `vector<f32>`.

The `encoding` attribute provides additional information on the vector.
An empty attribute denotes a straightforward vector without any specific
structure.

Note: hexadecimal integer literals are not allowed in vector type
declarations, `vector<0x42xi32>` is invalid because it is interpreted as a
2D vector with shape `(0, 42)` and zero shapes are not allowed.
Expand All @@ -1094,17 +1099,22 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
// A 3D mixed fixed/scalable vector in which only the inner dimension is
// scalable.
vector<2x[4]x8xf32>

// Vector with an encoding attribute (where #ENCODING is a named alias).
vector<4x2xf64, #ENCODING>
```
}];
let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
"Type":$elementType,
ArrayRefParameter<"bool">:$scalableDims
ArrayRefParameter<"bool">:$scalableDims,
"Attribute":$encoding
);
let builders = [
TypeBuilderWithInferredContext<(ins
"ArrayRef<int64_t>":$shape, "Type":$elementType,
CArg<"ArrayRef<bool>", "{}">:$scalableDims
CArg<"ArrayRef<bool>", "{}">:$scalableDims,
CArg<"Attribute", "{}">:$encoding
), [{
// While `scalableDims` is optional, its default value should be
// `false` for every dim in `shape`.
Expand All @@ -1113,7 +1123,8 @@ def Builtin_Vector : Builtin_Type<"Vector", "vector", [ShapedTypeInterface], "Ty
isScalableVec.resize(shape.size(), false);
scalableDims = isScalableVec;
}
return $_get(elementType.getContext(), shape, elementType, scalableDims);
auto ctx = elementType.getContext();
return $_get(ctx, shape, elementType, scalableDims, encoding);
}]>
];
let extraClassDeclaration = [{
Expand Down
20 changes: 18 additions & 2 deletions mlir/lib/AsmParser/TypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ Type Parser::parseTupleType() {
/// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
/// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
/// static-dim-list ::= decimal-literal (`x` decimal-literal)*
/// encoding ::= attribute-value
///
VectorType Parser::parseVectorType() {
consumeToken(Token::kw_vector);
Expand All @@ -467,14 +468,29 @@ VectorType Parser::parseVectorType() {
// Parse the element type.
auto typeLoc = getToken().getLoc();
auto elementType = parseType();

// Parse an optional encoding attribute.
Attribute encoding;
if (consumeIf(Token::comma)) {
auto parseResult = parseOptionalAttribute(encoding);
if (parseResult.has_value()) {
if (failed(parseResult.value()))
return nullptr;
if (auto v = dyn_cast_or_null<VerifiableTensorEncoding>(encoding)) {
if (failed(v.verifyEncoding(dimensions, elementType,
[&] { return emitError(); })))
return nullptr;
}
}
}

if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;

if (!VectorType::isValidElementType(elementType))
return emitError(typeLoc, "vector elements must be int/index/float type"),
nullptr;

return VectorType::get(dimensions, elementType, scalableDims);
return VectorType::get(dimensions, elementType, scalableDims, encoding);
}

/// Parse a dimension list in a vector type. This populates the dimension list.
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2622,6 +2622,11 @@ void AsmPrinter::Impl::printTypeImpl(Type type) {
os << 'x';
}
printType(vectorTy.getElementType());
// Only print the encoding attribute value if set.
if (vectorTy.getEncoding()) {
os << ", ";
printAttribute(vectorTy.getEncoding());
}
os << '>';
})
.Case<RankedTensorType>([&](RankedTensorType tensorTy) {
Expand Down
8 changes: 6 additions & 2 deletions mlir/lib/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,

LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
ArrayRef<bool> scalableDims) {
ArrayRef<bool> scalableDims,
Attribute encoding) {
if (!isValidElementType(elementType))
return emitError()
<< "vector elements must be int/index/float type but got "
Expand All @@ -242,6 +243,9 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "number of dims must match, got "
<< scalableDims.size() << " and " << shape.size();

if (auto v = llvm::dyn_cast_or_null<VerifiableTensorEncoding>(encoding))
if (failed(v.verifyEncoding(shape, elementType, emitError)))
return failure();
return success();
}

Expand All @@ -260,7 +264,7 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
return VectorType::get(shape.value_or(getShape()), elementType,
getScalableDims());
getScalableDims(), getEncoding());
}

//===----------------------------------------------------------------------===//
Expand Down
3 changes: 3 additions & 0 deletions mlir/test/IR/parser.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ func.func private @float_types(f80, f128)
// CHECK: func private @vectors(vector<f32>, vector<1xf32>, vector<2x4xf32>)
func.func private @vectors(vector<f32>, vector<1 x f32>, vector<2x4xf32>)

// CHECK: func private @vector_encoding(vector<16x32xf64, "indexed">)
func.func private @vector_encoding(vector<16x32xf64, "indexed">)

// CHECK: func private @tensors(tensor<*xf32>, tensor<*xvector<2x4xf32>>, tensor<1x?x4x?x?xi32>, tensor<i8>)
func.func private @tensors(tensor<* x f32>, tensor<* x vector<2x4xf32>>,
tensor<1x?x4x?x?xi32>, tensor<i8>)
Expand Down
Loading