Skip to content

Commit c137271

Browse files
committed
[OMPIRBuilder][OpenMP][LLVM] Modify and use ReplaceConstant utility in convertTarget
This PR seeks to expand/replace the Constant -> Instruction conversion that needs to occur inside the OpenMP target kernel generation, so that uses within the kernel can be replaced with kernel arguments (constant uses within constant expressions cannot be replaced with non-constants). It does so by making use of the new-ish utility convertUsersOfConstantsToInstructions, which is a much more expansive version of the smaller function I wrote: it effectively expands uses of the input argument that are constant expressions into instructions so that we can replace them with the appropriate kernel argument. The PR also alters convertUsersOfConstantsToInstructions to optionally leave dead constants alone. This is necessary when lowering from MLIR, as we cannot be sure we can remove the constants at this stage: even if they have been rewritten to instructions, the ModuleTranslation may maintain links to the original constants and utilise them in further lowering steps (when we're lowering the kernel, the module is still in the process of being lowered). Removing them can result in unusual ICEs later. These dead constants can be tidied up later (and, from checking with emit-llvm, they appear to be in subsequent lowering). The one possible downside to this replacement is that the constant -> instruction rewriting is no longer constrained to within the kernel; it will expand all available uses of an input argument that is constant and has constant uses in the module. This hasn't lowered the correctness of the examples I have tested with; however, it may impact performance. A possibility in the future may be to optionally constrain rewrites of uses of constants in convertUsersOfConstantsToInstructions to a provided llvm::Function.
1 parent 0295c2a commit c137271

File tree

5 files changed

+79
-37
lines changed

5 files changed

+79
-37
lines changed

llvm/include/llvm/IR/ReplaceConstant.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ class Constant;
2121

2222
/// Replace constant expressions users of the given constants with
2323
/// instructions. Return whether anything was changed.
24-
bool convertUsersOfConstantsToInstructions(ArrayRef<Constant *> Consts);
24+
bool convertUsersOfConstantsToInstructions(ArrayRef<Constant *> Consts,
25+
bool RemoveDeadConstants = true);
2526

2627
} // end namespace llvm
2728

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "llvm/IR/MDBuilder.h"
4141
#include "llvm/IR/Metadata.h"
4242
#include "llvm/IR/PassManager.h"
43+
#include "llvm/IR/ReplaceConstant.h"
4344
#include "llvm/IR/Value.h"
4445
#include "llvm/MC/TargetRegistry.h"
4546
#include "llvm/Support/CommandLine.h"
@@ -5092,27 +5093,6 @@ FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize,
50925093
return getOrCreateRuntimeFunction(M, Name);
50935094
}
50945095

5095-
static void replaceConstatExprUsesInFuncWithInstr(ConstantExpr *ConstExpr,
5096-
Function *Func) {
5097-
for (User *User : make_early_inc_range(ConstExpr->users())) {
5098-
if (auto *Instr = dyn_cast<Instruction>(User)) {
5099-
if (Instr->getFunction() == Func) {
5100-
Instruction *ConstInst = ConstExpr->getAsInstruction();
5101-
ConstInst->insertBefore(*Instr->getParent(), Instr->getIterator());
5102-
Instr->replaceUsesOfWith(ConstExpr, ConstInst);
5103-
}
5104-
}
5105-
}
5106-
}
5107-
5108-
static void replaceConstantValueUsesInFuncWithInstr(llvm::Value *Input,
5109-
Function *Func) {
5110-
for (User *User : make_early_inc_range(Input->users()))
5111-
if (auto *Const = dyn_cast<Constant>(User))
5112-
if (auto *ConstExpr = dyn_cast<ConstantExpr>(Const))
5113-
replaceConstatExprUsesInFuncWithInstr(ConstExpr, Func);
5114-
}
5115-
51165096
static Function *createOutlinedFunction(
51175097
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
51185098
SmallVectorImpl<Value *> &Inputs,
@@ -5191,17 +5171,23 @@ static Function *createOutlinedFunction(
51915171

51925172
// Things like GEP's can come in the form of Constants. Constants and
51935173
// ConstantExpr's do not have access to the knowledge of what they're
5194-
// contained in, so we must dig a little to find an instruction so we can
5195-
// tell if they're used inside of the function we're outlining. We also
5196-
// replace the original constant expression with a new instruction
5174+
// contained in, so we must dig a little to find an instruction so we
5175+
// can tell if they're used inside of the function we're outlining. We
5176+
// also replace the original constant expression with a new instruction
51975177
// equivalent; an instruction as it allows easy modification in the
5198-
// following loop, as we can now know the constant (instruction) is owned by
5199-
// our target function and replaceUsesOfWith can now be invoked on it
5200-
// (cannot do this with constants it seems). A brand new one also allows us
5201-
// to be cautious as it is perhaps possible the old expression was used
5202-
// inside of the function but exists and is used externally (unlikely by the
5203-
// nature of a Constant, but still).
5204-
replaceConstantValueUsesInFuncWithInstr(Input, Func);
5178+
// following loop, as we can now know the constant (instruction) is
5179+
// owned by our target function and replaceUsesOfWith can now be invoked
5180+
// on it (cannot do this with constants it seems). A brand new one also
5181+
// allows us to be cautious as it is perhaps possible the old expression
5182+
// was used inside of the function but exists and is used externally
5183+
// (unlikely by the nature of a Constant, but still).
5184+
// NOTE: We cannot remove dead constants that have been rewritten to
5185+
// instructions at this stage, we run the risk of breaking later lowering
5186+
// by doing so as we could still be in the process of lowering the module
5187+
// from MLIR to LLVM-IR and the MLIR lowering may still require the original
5188+
// constants we have created rewritten versions of.
5189+
if (auto *Const = dyn_cast<Constant>(Input))
5190+
convertUsersOfConstantsToInstructions({Const}, false);
52055191

52065192
// Collect all the instructions
52075193
for (User *User : make_early_inc_range(Input->users()))

llvm/lib/IR/ReplaceConstant.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,8 @@ static SmallVector<Instruction *, 4> expandUser(BasicBlock::iterator InsertPt,
4949
return NewInsts;
5050
}
5151

52-
bool convertUsersOfConstantsToInstructions(ArrayRef<Constant *> Consts) {
52+
bool convertUsersOfConstantsToInstructions(ArrayRef<Constant *> Consts,
53+
bool RemoveDeadConstants) {
5354
// Find all expandable direct users of Consts.
5455
SmallVector<Constant *> Stack;
5556
for (Constant *C : Consts)
@@ -102,8 +103,9 @@ bool convertUsersOfConstantsToInstructions(ArrayRef<Constant *> Consts) {
102103
}
103104
}
104105

105-
for (Constant *C : Consts)
106-
C->removeDeadConstantUsers();
106+
if (RemoveDeadConstants)
107+
for (Constant *C : Consts)
108+
C->removeDeadConstantUsers();
107109

108110
return Changed;
109111
}

mlir/test/Target/LLVMIR/omptarget-fortran-allocatable-types-host.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,11 @@ module attributes {omp.is_target_device = false} {
6666

6767
// CHECK: define void @_QQmain()
6868
// CHECK: %[[SCALAR_ALLOCA:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8 }, i64 1, align 8
69-
// CHECK: %[[FULL_ARR_SIZE5:.*]] = load i64, ptr getelementptr ({ ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr @[[FULL_ARR_GLOB]], i32 0, i32 7, i64 0, i32 1), align 4
69+
// CHECK: %[[FULL_ARR_GEP:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr @[[FULL_ARR_GLOB]], i32 0, i32 7, i64 0, i32 1
70+
// CHECK: %[[FULL_ARR_SIZE5:.*]] = load i64, ptr %[[FULL_ARR_GEP]], align 4
7071
// CHECK: %[[FULL_ARR_SIZE4:.*]] = sub i64 %[[FULL_ARR_SIZE5]], 1
71-
// CHECK: %[[ARR_SECT_OFFSET3:.*]] = load i64, ptr getelementptr ({ ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr @[[ARR_SECT_GLOB]], i32 0, i32 7, i64 0, i32 0), align 4
72+
// CHECK: %[[ARR_SECT_GEP:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, ptr @[[ARR_SECT_GLOB]], i32 0, i32 7, i64 0, i32 0
73+
// CHECK: %[[ARR_SECT_OFFSET3:.*]] = load i64, ptr %[[ARR_SECT_GEP]], align 4
7274
// CHECK: %[[ARR_SECT_OFFSET2:.*]] = sub i64 2, %[[ARR_SECT_OFFSET3]]
7375
// CHECK: %[[ARR_SECT_SIZE4:.*]] = sub i64 5, %[[ARR_SECT_OFFSET3]]
7476
// CHECK: %[[SCALAR_BASE:.*]] = getelementptr { ptr, i64, i32, i8, i8, i8, i8 }, ptr %[[SCALAR_ALLOCA]], i32 0, i32 0
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
! Offloading test which maps a specific element of a
2+
! derived type to the device and then accesses the
3+
! element alongside an individual element of an array
4+
! that the derived type contains. In particular, this
5+
! test helps to check that we can replace the constants
6+
! within the kernel with instructions and then replace
7+
! these instructions with the kernel parameters.
8+
! REQUIRES: flang
9+
! UNSUPPORTED: nvptx64-nvidia-cuda
10+
! UNSUPPORTED: nvptx64-nvidia-cuda-LTO
11+
! UNSUPPORTED: aarch64-unknown-linux-gnu
12+
! UNSUPPORTED: aarch64-unknown-linux-gnu-LTO
13+
! UNSUPPORTED: x86_64-pc-linux-gnu
14+
! UNSUPPORTED: x86_64-pc-linux-gnu-LTO
15+
16+
! RUN: %libomptarget-compile-fortran-run-and-check-generic
17+
module test_0
18+
type dtype
19+
integer elements(20)
20+
integer value
21+
end type dtype
22+
23+
type (dtype) array_dtype(5)
24+
contains
25+
26+
subroutine assign()
27+
implicit none
28+
!$omp target map(tofrom: array_dtype(5))
29+
array_dtype(5)%elements(5) = 500
30+
!$omp end target
31+
end subroutine
32+
33+
subroutine add()
34+
implicit none
35+
36+
!$omp target map(tofrom: array_dtype(5))
37+
array_dtype(5)%elements(5) = array_dtype(5)%elements(5) + 500
38+
!$omp end target
39+
end subroutine
40+
end module test_0
41+
42+
program main
43+
use test_0
44+
45+
call assign()
46+
call add()
47+
48+
print *, array_dtype(5)%elements(5)
49+
end program
50+
51+
! CHECK: 1000

0 commit comments

Comments
 (0)