Skip to content

[mlir][LLVM] Delete getVectorElementType #134981

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions mlir/docs/Dialects/LLVM.md
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,6 @@ compatible with the LLVM dialect:

- `bool LLVM::isCompatibleVectorType(Type)` - checks whether a type is a
vector type compatible with the LLVM dialect;
- `Type LLVM::getVectorElementType(Type)` - returns the element type of any
vector type compatible with the LLVM dialect;
- `llvm::ElementCount LLVM::getVectorNumElements(Type)` - returns the number
of elements in any vector type compatible with the LLVM dialect;
- `Type LLVM::getFixedVectorType(Type, unsigned)` - gets a fixed vector type
Expand Down
14 changes: 8 additions & 6 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ def LLVM_MatrixColumnMajorLoadOp : LLVM_OneResultIntrOp<"matrix.column.major.loa
const llvm::DataLayout &dl =
builder.GetInsertBlock()->getModule()->getDataLayout();
llvm::Type *ElemTy = moduleTranslation.convertType(
getVectorElementType(op.getType()));
op.getType().getElementType());
llvm::Align align = dl.getABITypeAlign(ElemTy);
$res = mb.CreateColumnMajorLoad(
ElemTy, $data, align, $stride, $isVolatile, $rows,
Expand Down Expand Up @@ -907,7 +907,7 @@ def LLVM_MatrixColumnMajorStoreOp : LLVM_ZeroResultIntrOp<"matrix.column.major.s
llvm::MatrixBuilder mb(builder);
const llvm::DataLayout &dl =
builder.GetInsertBlock()->getModule()->getDataLayout();
Type elementType = getVectorElementType(op.getMatrix().getType());
Type elementType = op.getMatrix().getType().getElementType();
llvm::Align align = dl.getABITypeAlign(
moduleTranslation.convertType(elementType));
mb.CreateColumnMajorStore(
Expand Down Expand Up @@ -1164,7 +1164,8 @@ def LLVM_vector_insert
let extraClassDeclaration = [{
uint64_t getVectorBitWidth(Type vector) {
return getVectorNumElements(vector).getKnownMinValue() *
getVectorElementType(vector).getIntOrFloatBitWidth();
::llvm::cast<VectorType>(vector).getElementType()
.getIntOrFloatBitWidth();
}
uint64_t getSrcVectorBitWidth() {
return getVectorBitWidth(getSrcvec().getType());
Expand Down Expand Up @@ -1196,7 +1197,8 @@ def LLVM_vector_extract
let extraClassDeclaration = [{
uint64_t getVectorBitWidth(Type vector) {
return getVectorNumElements(vector).getKnownMinValue() *
getVectorElementType(vector).getIntOrFloatBitWidth();
::llvm::cast<VectorType>(vector).getElementType()
.getIntOrFloatBitWidth();
}
uint64_t getSrcVectorBitWidth() {
return getVectorBitWidth(getSrcvec().getType());
Expand All @@ -1216,8 +1218,8 @@ def LLVM_vector_interleave2
"result has twice as many elements as 'vec1'",
And<[CPred<"getVectorNumElements($res.getType()) == "
"getVectorNumElements($vec1.getType()) * 2">,
CPred<"getVectorElementType($vec1.getType()) == "
"getVectorElementType($res.getType())">]>>,
CPred<"::llvm::cast<VectorType>($vec1.getType()).getElementType() == "
"::llvm::cast<VectorType>($res.getType()).getElementType()">]>>,
]>,
Arguments<(ins LLVM_AnyVector:$vec1, LLVM_AnyVector:$vec2)>;

Expand Down
14 changes: 9 additions & 5 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -113,27 +113,31 @@ def LLVM_AnyNonAggregate : Type<And<[LLVM_Type.predicate,

// Type constraint accepting any LLVM vector type.
def LLVM_AnyVector : Type<CPred<"::mlir::LLVM::isCompatibleVectorType($_self)">,
"LLVM dialect-compatible vector type">;
"LLVM dialect-compatible vector type",
"::mlir::VectorType">;

// Type constraint accepting any LLVM fixed-length vector type.
def LLVM_AnyFixedVector : Type<CPred<
"!::mlir::LLVM::isScalableVectorType($_self)">,
"LLVM dialect-compatible fixed-length vector type">;
"LLVM dialect-compatible fixed-length vector type",
"::mlir::VectorType">;

// Type constraint accepting any LLVM scalable vector type.
def LLVM_AnyScalableVector : Type<CPred<
"::mlir::LLVM::isScalableVectorType($_self)">,
"LLVM dialect-compatible scalable vector type">;
"LLVM dialect-compatible scalable vector type",
"::mlir::VectorType">;

// Type constraint accepting an LLVM vector type with an additional constraint
// on the vector element type.
class LLVM_VectorOf<Type element> : Type<
And<[LLVM_AnyVector.predicate,
SubstLeaves<
"$_self",
"::mlir::LLVM::getVectorElementType($_self)",
"::llvm::cast<::mlir::VectorType>($_self).getElementType()",
element.predicate>]>,
"LLVM dialect-compatible vector of " # element.summary>;
"LLVM dialect-compatible vector of " # element.summary,
"::mlir::VectorType">;

// Type constraint accepting a constrained type, or a vector of such types.
class LLVM_ScalarOrVectorOf<Type element> :
Expand Down
8 changes: 5 additions & 3 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -820,8 +820,9 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
//===----------------------------------------------------------------------===//

def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [Pure,
TypesMatchWith<"result type matches vector element type", "vector", "res",
"LLVM::getVectorElementType($_self)">]> {
TypesMatchWith<
"result type matches vector element type", "vector", "res",
"::llvm::cast<::mlir::VectorType>($_self).getElementType()">]> {
let summary = "Extract an element from an LLVM vector.";

let arguments = (ins LLVM_AnyVector:$vector, AnySignlessInteger:$position);
Expand Down Expand Up @@ -881,7 +882,8 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [Pure]> {

def LLVM_InsertElementOp : LLVM_Op<"insertelement", [Pure,
TypesMatchWith<"argument type matches vector element type", "vector",
"value", "LLVM::getVectorElementType($_self)">,
"value",
"::llvm::cast<::mlir::VectorType>($_self).getElementType()">,
AllTypesMatch<["res", "vector"]>]> {
let summary = "Insert an element into an LLVM vector.";

Expand Down
4 changes: 0 additions & 4 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,6 @@ bool isCompatibleFloatingPointType(Type type);
/// dialect pointers and LLVM dialect scalable vector types.
bool isCompatibleVectorType(Type type);

/// Returns the element type of any vector type compatible with the LLVM
/// dialect.
Type getVectorElementType(Type type);

/// Returns the element count of any LLVM-compatible vector type.
llvm::ElementCount getVectorNumElements(Type type);

Expand Down
7 changes: 3 additions & 4 deletions mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,9 @@ static unsigned getBitWidth(Type type) {

/// Returns the bit width of LLVMType integer or vector.
static unsigned getLLVMTypeBitWidth(Type type) {
return cast<IntegerType>((LLVM::isCompatibleVectorType(type)
? LLVM::getVectorElementType(type)
: type))
.getWidth();
if (auto vecTy = dyn_cast<VectorType>(type))
type = vecTy.getElementType();
return cast<IntegerType>(type).getWidth();
}

/// Creates `IntegerAttribute` with all bits set for given type
Expand Down
20 changes: 11 additions & 9 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2734,9 +2734,9 @@ void ShuffleVectorOp::build(OpBuilder &builder, OperationState &state, Value v1,
Value v2, DenseI32ArrayAttr mask,
ArrayRef<NamedAttribute> attrs) {
auto containerType = v1.getType();
auto vType = LLVM::getVectorType(LLVM::getVectorElementType(containerType),
mask.size(),
LLVM::isScalableVectorType(containerType));
auto vType = LLVM::getVectorType(
cast<VectorType>(containerType).getElementType(), mask.size(),
LLVM::isScalableVectorType(containerType));
build(builder, state, vType, v1, v2, mask);
state.addAttributes(attrs);
}
Expand All @@ -2752,8 +2752,9 @@ static ParseResult parseShuffleType(AsmParser &parser, Type v1Type,
if (!LLVM::isCompatibleVectorType(v1Type))
return parser.emitError(parser.getCurrentLocation(),
"expected an LLVM compatible vector type");
resType = LLVM::getVectorType(LLVM::getVectorElementType(v1Type), mask.size(),
LLVM::isScalableVectorType(v1Type));
resType =
LLVM::getVectorType(cast<VectorType>(v1Type).getElementType(),
mask.size(), LLVM::isScalableVectorType(v1Type));
return success();
}

Expand Down Expand Up @@ -3318,7 +3319,7 @@ LogicalResult AtomicRMWOp::verify() {
if (isCompatibleVectorType(valType)) {
if (isScalableVectorType(valType))
return emitOpError("expected LLVM IR fixed vector type");
Type elemType = getVectorElementType(valType);
Type elemType = llvm::cast<VectorType>(valType).getElementType();
if (!isCompatibleFloatingPointType(elemType))
return emitOpError(
"expected LLVM IR floating point type for vector element");
Expand Down Expand Up @@ -3423,9 +3424,10 @@ static LogicalResult verifyExtOp(ExtOp op) {
return op.emitError("input and output vectors are of incompatible shape");
// Because this is a CastOp, the element of vectors is guaranteed to be an
// integer.
inputType = cast<IntegerType>(getVectorElementType(op.getArg().getType()));
outputType =
cast<IntegerType>(getVectorElementType(op.getResult().getType()));
inputType = cast<IntegerType>(
cast<VectorType>(op.getArg().getType()).getElementType());
outputType = cast<IntegerType>(
cast<VectorType>(op.getResult().getType()).getElementType());
} else {
// Because this is a CastOp and arg is not a vector, arg is guaranteed to be
// an integer.
Expand Down
6 changes: 0 additions & 6 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -821,12 +821,6 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) {
return false;
}

Type mlir::LLVM::getVectorElementType(Type type) {
auto vecTy = dyn_cast<VectorType>(type);
assert(vecTy && "incompatible with LLVM vector type");
return vecTy.getElementType();
}

llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) {
auto vecTy = dyn_cast<VectorType>(type);
assert(vecTy && "incompatible with LLVM vector type");
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Target/LLVMIR/ModuleImport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,7 @@ static Type getVectorTypeForAttr(Type type, ArrayRef<int64_t> arrayShape = {}) {
}

// An LLVM dialect vector can only contain scalars.
Type elementType = LLVM::getVectorElementType(type);
Type elementType = cast<VectorType>(type).getElementType();
if (!elementType.isIntOrFloat())
return {};

Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/LLVMIR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -515,21 +515,21 @@ func.func @extractvalue_wrong_nesting() {
// -----

func.func @invalid_vector_type_1(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
// expected-error@+1 {{'vector' must be LLVM dialect-compatible vector}}
// expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
%0 = llvm.extractelement %arg2[%arg1 : i32] : f32
}

// -----

func.func @invalid_vector_type_2(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
// expected-error@+1 {{'vector' must be LLVM dialect-compatible vector}}
// expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
%0 = llvm.insertelement %arg2, %arg2[%arg1 : i32] : f32
}

// -----

func.func @invalid_vector_type_3(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32) {
// expected-error@+2 {{expected an LLVM compatible vector type}}
// expected-error@+1 {{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
%0 = llvm.shufflevector %arg2, %arg2 [0, 0, 0, 0, 7] : f32
}

Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Target/LLVMIR/llvmir-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ llvm.func @vec_reduce_fmax_intr_wrong_type(%arg0 : vector<4xi32>) -> i32 {
// -----

llvm.func @matrix_load_intr_wrong_type(%ptr : !llvm.ptr, %stride : i32) -> f32 {
// expected-error @below{{op result #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
// expected-error @+2{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
%0 = llvm.intr.matrix.column.major.load %ptr, <stride=%stride>
{ isVolatile = 0: i1, rows = 3: i32, columns = 16: i32} : f32 from !llvm.ptr stride i32
llvm.return %0 : f32
Expand All @@ -229,7 +229,7 @@ llvm.func @matrix_store_intr_wrong_type(%matrix : vector<48xf32>, %ptr : i32, %s
// -----

llvm.func @matrix_multiply_intr_wrong_type(%arg0 : vector<64xf32>, %arg1 : f32) -> vector<12xf32> {
// expected-error @below{{op operand #1 must be LLVM dialect-compatible vector type, but got 'f32'}}
// expected-error @+2{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
%0 = llvm.intr.matrix.multiply %arg0, %arg1
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32} : (vector<64xf32>, f32) -> vector<12xf32>
llvm.return %0 : vector<12xf32>
Expand All @@ -238,7 +238,7 @@ llvm.func @matrix_multiply_intr_wrong_type(%arg0 : vector<64xf32>, %arg1 : f32)
// -----

llvm.func @matrix_transpose_intr_wrong_type(%matrix : f32) -> vector<48xf32> {
// expected-error @below{{op operand #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
// expected-error @below{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
%0 = llvm.intr.matrix.transpose %matrix {rows = 3: i32, columns = 16: i32} : f32 into vector<48xf32>
llvm.return %0 : vector<48xf32>
}
Expand Down Expand Up @@ -286,7 +286,7 @@ llvm.func @masked_gather_intr_wrong_type_scalable(%ptrs : vector<7x!llvm.ptr>, %
// -----

llvm.func @masked_scatter_intr_wrong_type(%vec : f32, %ptrs : vector<7x!llvm.ptr>, %mask : vector<7xi1>) {
// expected-error @below{{op operand #0 must be LLVM dialect-compatible vector type, but got 'f32'}}
// expected-error @below{{invalid kind of type specified: expected builtin.vector, but found 'f32'}}
llvm.intr.masked.scatter %vec, %ptrs, %mask { alignment = 1: i32} : f32, vector<7xi1> into vector<7x!llvm.ptr>
llvm.return
}
Expand Down