Skip to content

Commit fe283a1

Browse files
authored
[mlir][llvm] Fix elem type passing into getelementptr (#68136)
As was correctly pointed out by @azteca1998, the element type for a `llvm.getelementptr` was only read when using an attribute and not when using a type. As pointed out in #63832 (comment), the translation to LLVM would work for ```mlir llvm.func @main(%0 : !llvm.ptr) -> !llvm.ptr { %1 = llvm.getelementptr %0[0] { elem_type = !llvm.ptr } : (!llvm.ptr) -> !llvm.ptr llvm.return %1 : !llvm.ptr } ``` but not for ```mlir llvm.func @main(%0 : !llvm.ptr) -> !llvm.ptr<ptr> { %1 = llvm.getelementptr %0[0] : (!llvm.ptr) -> !llvm.ptr<ptr> llvm.return %1 : !llvm.ptr<ptr> } ``` This was caused by the `LLVM_GEPOp` builder only reading the type from the attribute (`{ elem_type = !llvm.ptr }`), but not from the pointer type (`!llvm.ptr<ptr>`). Fixes #63832. EDIT: During review Markus Böck pointed out that this bugfix adds new functionality for typed pointers, but this functionality shouldn't be there in the first place. In response, Oleksandr "Alex" Zinenko pointed out that this is okay for now since the typed pointers will be removed in an upcoming release anyway, so it's best to merge this PR and spend time on the removal instead.
1 parent 77feba5 commit fe283a1

File tree

3 files changed

+22
-5
lines changed

3 files changed

+22
-5
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,8 +303,13 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
303303
indices.push_back(
304304
builder.getInt32(valueOrAttr.get<IntegerAttr>().getInt()));
305305
}
306-
Type baseElementType = op.getSourceElementType();
307-
llvm::Type *elementType = moduleTranslation.convertType(baseElementType);
306+
307+
Type elemTypeFromAttr = op.getSourceElementType();
308+
auto ptrType = ::llvm::cast<LLVMPointerType>(op.getType());
309+
Type elemTypeFromPtrType = ptrType.getElementType();
310+
311+
llvm::Type *elementType = moduleTranslation.convertType(
312+
elemTypeFromAttr ? elemTypeFromAttr : elemTypeFromPtrType);
308313
$res = builder.CreateGEP(elementType, $base, indices, "", $inbounds);
309314
}];
310315
let assemblyFormat = [{

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -287,14 +287,17 @@ ParseResult AllocaOp::parse(OpAsmParser &parser, OperationState &result) {
287287
}
288288

289289
/// Checks that the elemental type is present in either the pointer type or
290-
/// the attribute, but not both.
290+
/// the attribute, but not in none or both.
291291
static LogicalResult verifyOpaquePtr(Operation *op, LLVMPointerType ptrType,
292292
std::optional<Type> ptrElementType) {
293-
if (ptrType.isOpaque() && !ptrElementType.has_value()) {
293+
bool typePresentInPointerType = !ptrType.isOpaque();
294+
bool typePresentInAttribute = ptrElementType.has_value();
295+
296+
if (!typePresentInPointerType && !typePresentInAttribute) {
294297
return op->emitOpError() << "expected '" << kElemTypeAttrName
295298
<< "' attribute if opaque pointer type is used";
296299
}
297-
if (!ptrType.isOpaque() && ptrElementType.has_value()) {
300+
if (typePresentInPointerType && typePresentInAttribute) {
298301
return op->emitOpError()
299302
<< "unexpected '" << kElemTypeAttrName
300303
<< "' attribute when non-opaque pointer type is used";

mlir/test/Target/LLVMIR/opaque-ptr.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ llvm.func @opaque_ptr_gep_struct(%arg0: !llvm.ptr, %arg1: i32) -> !llvm.ptr {
4242
llvm.return %0 : !llvm.ptr
4343
}
4444

45+
// CHECK-LABEL: @opaque_ptr_elem_type
46+
llvm.func @opaque_ptr_elem_type(%0: !llvm.ptr) -> !llvm.ptr {
47+
// CHECK: getelementptr ptr, ptr
48+
%1 = llvm.getelementptr %0[0] { elem_type = !llvm.ptr } : (!llvm.ptr) -> !llvm.ptr
49+
// CHECK: getelementptr ptr, ptr
50+
%2 = llvm.getelementptr %0[0] : (!llvm.ptr) -> !llvm.ptr<ptr>
51+
llvm.return %1 : !llvm.ptr
52+
}
53+
4554
// CHECK-LABEL: @opaque_ptr_matrix_load_store
4655
llvm.func @opaque_ptr_matrix_load_store(%ptr: !llvm.ptr, %stride: i64) -> vector<48 x f32> {
4756
// CHECK: call <48 x float> @llvm.matrix.column.major.load.v48f32.i64

0 commit comments

Comments
 (0)