Skip to content

[mlir][OpenMP] implement SIMD reduction #146671

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
}
};
auto checkReduction = [&todo](auto op, LogicalResult &result) {
if (isa<omp::TeamsOp>(op) || isa<omp::SimdOp>(op))
if (isa<omp::TeamsOp>(op))
if (!op.getReductionVars().empty() || op.getReductionByref() ||
op.getReductionSyms())
result = todo("reduction");
Expand Down Expand Up @@ -2864,6 +2864,17 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,

PrivateVarsInfo privateVarsInfo(simdOp);

MutableArrayRef<BlockArgument> reductionArgs =
cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
DenseMap<Value, llvm::Value *> reductionVariableMap;
SmallVector<llvm::Value *> privateReductionVariables(
simdOp.getNumReductionVars());
SmallVector<DeferredStore> deferredStores;
SmallVector<omp::DeclareReductionOp> reductionDecls;
collectReductionDecls(simdOp, reductionDecls);
llvm::ArrayRef<bool> isByRef = getIsByRef(simdOp.getReductionByref());
assert(isByRef.size() == simdOp.getNumReductionVars());

llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);

Expand All @@ -2872,11 +2883,27 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
if (handleError(afterAllocas, opInst).failed())
return failure();

if (failed(allocReductionVars(simdOp, reductionArgs, builder,
moduleTranslation, allocaIP, reductionDecls,
privateReductionVariables, reductionVariableMap,
deferredStores, isByRef)))
return failure();

if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo),
opInst)
.failed())
return failure();

// TODO: no call to copyFirstPrivateVars?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That seems to be missing indeed. I'm guessing firstprivate probably doesn't currently work for simd, then.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll fix it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh it turns out firstprivate isn't listed as an allowed clause for simd. I will update the comment.

Copy link
Contributor Author

@tblah tblah Jul 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dialect does allow it so I have done this in a separate PR which I will post after this is merged. Edit: #146734


assert(afterAllocas.get()->getSinglePredecessor());
if (failed(initReductionVars(simdOp, reductionArgs, builder,
moduleTranslation,
afterAllocas.get()->getSinglePredecessor(),
reductionDecls, privateReductionVariables,
reductionVariableMap, isByRef, deferredStores)))
return failure();

llvm::ConstantInt *simdlen = nullptr;
if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
simdlen = builder.getInt64(simdlenVar.value());
Expand Down Expand Up @@ -2921,6 +2948,50 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
: nullptr,
order, simdlen, safelen);

// We now need to reduce the per-simd-lane reduction variable into the
// original variable. This works a bit differently to other reductions (e.g.
// wsloop) because we don't need to call into the OpenMP runtime to handle
// threads: everything happened in this one thread.
for (auto [i, tuple] : llvm::enumerate(
llvm::zip(reductionDecls, isByRef, simdOp.getReductionVars(),
privateReductionVariables))) {
auto [decl, byRef, reductionVar, privateReductionVar] = tuple;

OwningReductionGen gen = makeReductionGen(decl, builder, moduleTranslation);
llvm::Value *originalVariable = moduleTranslation.lookupValue(reductionVar);
llvm::Type *reductionType = moduleTranslation.convertType(decl.getType());

// We have one less load for by-ref case because that load is now inside of
// the reduction region.
llvm::Value *redValue = originalVariable;
if (!byRef)
redValue =
builder.CreateLoad(reductionType, redValue, "red.value." + Twine(i));
llvm::Value *privateRedValue = builder.CreateLoad(
reductionType, privateReductionVar, "red.private.value." + Twine(i));
llvm::Value *reduced;

auto res = gen(builder.saveIP(), redValue, privateRedValue, reduced);
if (failed(handleError(res, opInst)))
return failure();
builder.restoreIP(res.get());

// For by-ref case, the store is inside of the reduction region.
if (!byRef)
builder.CreateStore(reduced, originalVariable);
}

// After the construct, deallocate private reduction variables.
SmallVector<Region *> reductionRegions;
llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
[](omp::DeclareReductionOp reductionDecl) {
return &reductionDecl.getCleanupRegion();
});
if (failed(inlineOmpRegionCleanup(reductionRegions, privateReductionVariables,
moduleTranslation, builder,
"omp.reduction.cleanup")))
return failure();

return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
privateVarsInfo.llvmVars,
privateVarsInfo.privatizers);
Expand Down
98 changes: 98 additions & 0 deletions mlir/test/Target/LLVMIR/openmp-simd-reduction-byref.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s

llvm.func @init(%arg0: !llvm.ptr {llvm.nocapture}, %arg1: !llvm.ptr {llvm.nocapture}) {
llvm.return
}
llvm.func @combine(%arg0: !llvm.ptr {llvm.nocapture}, %arg1: !llvm.ptr {llvm.nocapture}) {
llvm.return
}
llvm.func @cleanup(%arg0: !llvm.ptr {llvm.nocapture}) {
llvm.return
}
omp.private {type = private} @_QFsimd_reductionEi_private_i32 : i32
omp.declare_reduction @add_reduction_byref_box_2xf32 : !llvm.ptr alloc {
%0 = llvm.mlir.constant(1 : i64) : i64
%1 = llvm.alloca %0 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> : (i64) -> !llvm.ptr
omp.yield(%1 : !llvm.ptr)
} init {
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
llvm.call @init(%arg0, %arg1) : (!llvm.ptr, !llvm.ptr) -> ()
omp.yield(%arg1 : !llvm.ptr)
} combiner {
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
llvm.call @combine(%arg0, %arg1) : (!llvm.ptr, !llvm.ptr) -> ()
omp.yield(%arg0 : !llvm.ptr)
} cleanup {
^bb0(%arg0: !llvm.ptr):
llvm.call @cleanup(%arg0) : (!llvm.ptr) -> ()
omp.yield
}
llvm.func @_QPsimd_reduction(%arg0: !llvm.ptr {fir.bindc_name = "a", llvm.nocapture}, %arg1: !llvm.ptr {fir.bindc_name = "sum", llvm.nocapture}) {
%0 = llvm.mlir.constant(1024 : i32) : i32
%1 = llvm.mlir.constant(1 : i32) : i32
%2 = llvm.mlir.constant(1 : i64) : i64
%3 = llvm.alloca %2 x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>)> : (i64) -> !llvm.ptr
%4 = llvm.alloca %2 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
omp.simd private(@_QFsimd_reductionEi_private_i32 %4 -> %arg2 : !llvm.ptr) reduction(byref @add_reduction_byref_box_2xf32 %3 -> %arg3 : !llvm.ptr) {
omp.loop_nest (%arg4) : i32 = (%1) to (%0) inclusive step (%1) {
llvm.store %arg4, %arg2 : i32, !llvm.ptr
omp.yield
}
}
llvm.return
}

// CHECK-LABEL: define void @_QPsimd_reduction
// CHECK: %[[MOLD:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, i64 1, align 8
// CHECK: %[[ORIG_I:.*]] = alloca i32, i64 1, align 4
// CHECK: %[[PRIV_I:.*]] = alloca i32, align 4
// CHECK: %[[RED_VAR:.*]] = alloca { ptr, i64, i32, i8, i8, i8, i8, [1 x [3 x i64]] }, i64 1, align 8
// CHECK: %[[PTR_RED_VAR:.*]] = alloca ptr, align 8
// CHECK: br label %[[VAL_5:.*]]
// CHECK: omp.region.after_alloca: ; preds = %[[VAL_6:.*]]
// CHECK: br label %[[VAL_7:.*]]
// CHECK: entry: ; preds = %[[VAL_5]]
// CHECK: br label %[[VAL_8:.*]]
// CHECK: omp.private.init: ; preds = %[[VAL_7]]
// CHECK: br label %[[VAL_9:.*]]
// CHECK: omp.reduction.init: ; preds = %[[VAL_8]]
// CHECK: store ptr %[[RED_VAR]], ptr %[[PTR_RED_VAR]], align 8
// CHECK: call void @init(ptr %[[MOLD]], ptr %[[RED_VAR]])
// CHECK: br label %[[VAL_10:.*]]
// CHECK: omp.simd.region: ; preds = %[[VAL_9]]
// CHECK: br label %[[VAL_11:.*]]
// CHECK: omp_loop.preheader: ; preds = %[[VAL_10]]
// CHECK: br label %[[VAL_12:.*]]
// CHECK: omp_loop.header: ; preds = %[[VAL_13:.*]], %[[VAL_11]]
// CHECK: %[[VAL_14:.*]] = phi i32 [ 0, %[[VAL_11]] ], [ %[[VAL_15:.*]], %[[VAL_13]] ]
// CHECK: br label %[[VAL_16:.*]]
// CHECK: omp_loop.cond: ; preds = %[[VAL_12]]
// CHECK: %[[VAL_17:.*]] = icmp ult i32 %[[VAL_14]], 1024
// CHECK: br i1 %[[VAL_17]], label %[[VAL_18:.*]], label %[[VAL_19:.*]]
// CHECK: omp_loop.body: ; preds = %[[VAL_16]]
// CHECK: %[[VAL_20:.*]] = mul i32 %[[VAL_14]], 1
// CHECK: %[[VAL_21:.*]] = add i32 %[[VAL_20]], 1
// CHECK: br label %[[VAL_22:.*]]
// CHECK: omp.loop_nest.region: ; preds = %[[VAL_18]]
// CHECK: store i32 %[[VAL_21]], ptr %[[PRIV_I]], align 4, !llvm.access.group ![[ACCESS_GROUP:.*]]
// CHECK: br label %[[VAL_23:.*]]
// CHECK: omp.region.cont1: ; preds = %[[VAL_22]]
// CHECK: br label %[[VAL_13]]
// CHECK: omp_loop.inc: ; preds = %[[VAL_23]]
// CHECK: %[[VAL_15]] = add nuw i32 %[[VAL_14]], 1
// CHECK: br label %[[VAL_12]], !llvm.loop ![[LOOP:.*]]
// CHECK: omp_loop.exit: ; preds = %[[VAL_16]]
// CHECK: br label %[[VAL_24:.*]]
// CHECK: omp_loop.after: ; preds = %[[VAL_19]]
// CHECK: br label %[[VAL_25:.*]]
// CHECK: omp.region.cont: ; preds = %[[VAL_24]]
// CHECK: %[[RED_VAR2:.*]] = load ptr, ptr %[[PTR_RED_VAR]], align 8
// CHECK: call void @combine(ptr %[[MOLD]], ptr %[[RED_VAR2]])
// CHECK: %[[RED_VAR3:.*]] = load ptr, ptr %[[PTR_RED_VAR]], align 8
// CHECK: call void @cleanup(ptr %[[RED_VAR3]])
// CHECK: ret void

// CHECK: ![[ACCESS_GROUP]] = distinct !{}
// CHECK: ![[LOOP]] = distinct !{![[LOOP]], ![[PARALLEL_ACCESS:.*]], ![[VECTORIZE:.*]]}
// CHECK: ![[PARALLEL_ACCESS]] = !{!"llvm.loop.parallel_accesses", ![[ACCESS_GROUP]]}
// CHECK: ![[VECTORIZE]] = !{!"llvm.loop.vectorize.enable", i1 true}
96 changes: 96 additions & 0 deletions mlir/test/Target/LLVMIR/openmp-simd-reduction-simple.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s

omp.private {type = private} @_QFsimd_reductionEi_private_i32 : i32
omp.declare_reduction @add_reduction_f32 : f32 init {
^bb0(%arg0: f32):
%0 = llvm.mlir.constant(0.000000e+00 : f32) : f32
omp.yield(%0 : f32)
} combiner {
^bb0(%arg0: f32, %arg1: f32):
%0 = llvm.fadd %arg0, %arg1 {fastmathFlags = #llvm.fastmath<contract>} : f32
omp.yield(%0 : f32)
}
llvm.func @_QPsimd_reduction(%arg0: !llvm.ptr {fir.bindc_name = "a", llvm.nocapture}, %arg1: !llvm.ptr {fir.bindc_name = "sum", llvm.nocapture}) {
%0 = llvm.mlir.constant(0.000000e+00 : f32) : f32
%1 = llvm.mlir.constant(1 : i32) : i32
%2 = llvm.mlir.constant(1024 : i32) : i32
%3 = llvm.mlir.constant(1 : i64) : i64
%4 = llvm.alloca %3 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
llvm.store %0, %arg1 : f32, !llvm.ptr
omp.simd private(@_QFsimd_reductionEi_private_i32 %4 -> %arg2 : !llvm.ptr) reduction(@add_reduction_f32 %arg1 -> %arg3 : !llvm.ptr) {
omp.loop_nest (%arg4) : i32 = (%1) to (%2) inclusive step (%1) {
llvm.store %arg4, %arg2 : i32, !llvm.ptr
%5 = llvm.load %arg3 : !llvm.ptr -> f32
%6 = llvm.load %arg2 : !llvm.ptr -> i32
%7 = llvm.sext %6 : i32 to i64
%8 = llvm.sub %7, %3 overflow<nsw> : i64
%9 = llvm.getelementptr %arg0[%8] : (!llvm.ptr, i64) -> !llvm.ptr, f32
%10 = llvm.load %9 : !llvm.ptr -> f32
%11 = llvm.fadd %5, %10 {fastmathFlags = #llvm.fastmath<contract>} : f32
llvm.store %11, %arg3 : f32, !llvm.ptr
omp.yield
}
}
llvm.return
}

// CHECK-LABEL: define void @_QPsimd_reduction(
// CHECK: %[[ORIG_I:.*]] = alloca i32, i64 1, align 4
// CHECK: store float 0.000000e+00, ptr %[[ORIG_SUM:.*]], align 4
// CHECK: %[[PRIV_I:.*]] = alloca i32, align 4
// CHECK: %[[RED_VAR:.*]] = alloca float, align 4
// CHECK: br label %[[VAL_4:.*]]
// CHECK: omp.region.after_alloca: ; preds = %[[VAL_5:.*]]
// CHECK: br label %[[VAL_6:.*]]
// CHECK: entry: ; preds = %[[VAL_4]]
// CHECK: br label %[[VAL_7:.*]]
// CHECK: omp.private.init: ; preds = %[[VAL_6]]
// CHECK: br label %[[VAL_8:.*]]
// CHECK: omp.reduction.init: ; preds = %[[VAL_7]]
// CHECK: store float 0.000000e+00, ptr %[[RED_VAR]], align 4
// CHECK: br label %[[VAL_9:.*]]
// CHECK: omp.simd.region: ; preds = %[[VAL_8]]
// CHECK: br label %[[VAL_10:.*]]
// CHECK: omp_loop.preheader: ; preds = %[[VAL_9]]
// CHECK: br label %[[VAL_11:.*]]
// CHECK: omp_loop.header: ; preds = %[[VAL_12:.*]], %[[VAL_10]]
// CHECK: %[[VAL_13:.*]] = phi i32 [ 0, %[[VAL_10]] ], [ %[[VAL_14:.*]], %[[VAL_12]] ]
// CHECK: br label %[[VAL_15:.*]]
// CHECK: omp_loop.cond: ; preds = %[[VAL_11]]
// CHECK: %[[VAL_16:.*]] = icmp ult i32 %[[VAL_13]], 1024
// CHECK: br i1 %[[VAL_16]], label %[[VAL_17:.*]], label %[[VAL_18:.*]]
// CHECK: omp_loop.body: ; preds = %[[VAL_15]]
// CHECK: %[[VAL_19:.*]] = mul i32 %[[VAL_13]], 1
// CHECK: %[[VAL_20:.*]] = add i32 %[[VAL_19]], 1
// CHECK: br label %[[VAL_21:.*]]
// CHECK: omp.loop_nest.region: ; preds = %[[VAL_17]]
// CHECK: store i32 %[[VAL_20]], ptr %[[PRIV_I]], align 4, !llvm.access.group ![[ACCESS_GROUP:.*]]
// CHECK: %[[RED_VAL:.*]] = load float, ptr %[[RED_VAR]], align 4, !llvm.access.group ![[ACCESS_GROUP]]
// CHECK: %[[VAL_23:.*]] = load i32, ptr %[[PRIV_I]], align 4, !llvm.access.group ![[ACCESS_GROUP]]
// CHECK: %[[VAL_24:.*]] = sext i32 %[[VAL_23]] to i64
// CHECK: %[[VAL_25:.*]] = sub nsw i64 %[[VAL_24]], 1
// CHECK: %[[VAL_26:.*]] = getelementptr float, ptr %[[VAL_27:.*]], i64 %[[VAL_25]]
// CHECK: %[[VAL_28:.*]] = load float, ptr %[[VAL_26]], align 4, !llvm.access.group ![[ACCESS_GROUP]]
// CHECK: %[[VAL_29:.*]] = fadd contract float %[[RED_VAL]], %[[VAL_28]]
// CHECK: store float %[[VAL_29]], ptr %[[RED_VAR]], align 4, !llvm.access.group ![[ACCESS_GROUP]]
// CHECK: br label %[[VAL_30:.*]]
// CHECK: omp.region.cont1: ; preds = %[[VAL_21]]
// CHECK: br label %[[VAL_12]]
// CHECK: omp_loop.inc: ; preds = %[[VAL_30]]
// CHECK: %[[VAL_14]] = add nuw i32 %[[VAL_13]], 1
// CHECK: br label %[[VAL_11]], !llvm.loop ![[LOOP:.*]]
// CHECK: omp_loop.exit: ; preds = %[[VAL_15]]
// CHECK: br label %[[VAL_31:.*]]
// CHECK: omp_loop.after: ; preds = %[[VAL_18]]
// CHECK: br label %[[VAL_32:.*]]
// CHECK: omp.region.cont: ; preds = %[[VAL_31]]
// CHECK: %[[SUM_VAL:.*]] = load float, ptr %[[ORIG_SUM]], align 4
// CHECK: %[[RED_VAL:.*]] = load float, ptr %[[RED_VAR]], align 4
// CHECK: %[[COMBINED_VAL:.*]] = fadd contract float %[[SUM_VAL]], %[[RED_VAL]]
// CHECK: store float %[[COMBINED_VAL]], ptr %[[ORIG_SUM]], align 4
// CHECK: ret void

// CHECK: ![[ACCESS_GROUP]] = distinct !{}
// CHECK: ![[LOOP]] = distinct !{![[LOOP]], ![[PARALLEL_ACCESS:.*]], ![[VECTORIZE:.*]]}
// CHECK: ![[PARALLEL_ACCESS]] = !{!"llvm.loop.parallel_accesses", ![[ACCESS_GROUP]]}
// CHECK: ![[VECTORIZE]] = !{!"llvm.loop.vectorize.enable", i1 true}
30 changes: 0 additions & 30 deletions mlir/test/Target/LLVMIR/openmp-todo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -143,36 +143,6 @@ llvm.func @simd_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {

// -----

omp.declare_reduction @add_f32 : f32
init {
^bb0(%arg: f32):
%0 = llvm.mlir.constant(0.0 : f32) : f32
omp.yield (%0 : f32)
}
combiner {
^bb1(%arg0: f32, %arg1: f32):
%1 = llvm.fadd %arg0, %arg1 : f32
omp.yield (%1 : f32)
}
atomic {
^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
%2 = llvm.load %arg3 : !llvm.ptr -> f32
llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
omp.yield
}
llvm.func @simd_reduction(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
// expected-error@below {{not yet implemented: Unhandled clause reduction in omp.simd operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.simd}}
omp.simd reduction(@add_f32 %x -> %prv : !llvm.ptr) {
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
omp.yield
}
}
llvm.return
}

// -----

omp.declare_reduction @add_f32 : f32
init {
^bb0(%arg: f32):
Expand Down