From 71838b93bd5fd358a95e3a965fc95c52a931072d Mon Sep 17 00:00:00 2001 From: KristofferC Date: Thu, 24 Oct 2024 11:52:53 +0200 Subject: [PATCH] Revert "Remove llvm-muladd pass and move it's functionality to to llvm-simdloop (#55802)" This reverts commit 69ed5fdcdc59f345692fd6b0419a456aa4c0c907. --- doc/src/devdocs/llvm-passes.md | 12 +++ doc/src/devdocs/llvm.md | 1 + src/Makefile | 2 +- src/llvm-muladd.cpp | 117 ++++++++++++++++++++++++++++++ src/llvm-simdloop.cpp | 66 ----------------- src/passes.h | 11 +-- src/pipeline.cpp | 1 + test/llvmpasses/julia-simdloop.ll | 21 ------ test/llvmpasses/muladd.ll | 64 ++++++++++++++++ test/llvmpasses/parsing.ll | 2 +- 10 files changed, 201 insertions(+), 96 deletions(-) create mode 100644 src/llvm-muladd.cpp create mode 100644 test/llvmpasses/muladd.ll diff --git a/doc/src/devdocs/llvm-passes.md b/doc/src/devdocs/llvm-passes.md index 736faf54c219b..36383acaef512 100644 --- a/doc/src/devdocs/llvm-passes.md +++ b/doc/src/devdocs/llvm-passes.md @@ -114,6 +114,18 @@ This pass is used to verify Julia's invariants about LLVM IR. This includes thin These passes are used to perform transformations on LLVM IR that LLVM will not perform itself, e.g. fast math flag propagation, escape analysis, and optimizations on Julia-specific internal functions. They use knowledge about Julia's semantics to perform these optimizations. +### CombineMulAdd + +* Filename: `llvm-muladd.cpp` +* Class Name: `CombineMulAddPass` +* Opt Name: `function(CombineMulAdd)` + +This pass serves to optimize the particular combination of a regular `fmul` with a fast `fadd` into a contract `fmul` with a fast `fadd`. This is later optimized by the backend to a [fused multiply-add](https://en.wikipedia.org/wiki/Multiply%E2%80%93accumulate_operation#Fused_multiply%E2%80%93add) instruction, which can provide significantly faster operations at the cost of more [unpredictable semantics](https://simonbyrne.github.io/notes/fastmath/). + +!!! note + + This optimization only occurs when the `fmul` has a single use, which is the fast `fadd`. + ### AllocOpt * Filename: `llvm-alloc-opt.cpp` diff --git a/doc/src/devdocs/llvm.md b/doc/src/devdocs/llvm.md index 8884e7c91f2bf..170a812c09994 100644 --- a/doc/src/devdocs/llvm.md +++ b/doc/src/devdocs/llvm.md @@ -30,6 +30,7 @@ The code for lowering Julia AST to LLVM IR or interpreting it directly is in dir | `llvm-julia-licm.cpp` | Custom LLVM pass to hoist/sink Julia-specific intrinsics | | `llvm-late-gc-lowering.cpp` | Custom LLVM pass to root GC-tracked values | | `llvm-lower-handlers.cpp` | Custom LLVM pass to lower try-catch blocks | +| `llvm-muladd.cpp` | Custom LLVM pass for fast-match FMA | | `llvm-multiversioning.cpp` | Custom LLVM pass to generate sysimg code on multiple architectures | | `llvm-propagate-addrspaces.cpp` | Custom LLVM pass to canonicalize addrspaces | | `llvm-ptls.cpp` | Custom LLVM pass to lower TLS operations | diff --git a/src/Makefile b/src/Makefile index 46e4fe2fb5532..bf9001e5fba93 100644 --- a/src/Makefile +++ b/src/Makefile @@ -52,7 +52,7 @@ RT_LLVMLINK := CG_LLVMLINK := ifeq ($(JULIACODEGEN),LLVM) -CODEGEN_SRCS := codegen jitlayers aotcompile debuginfo disasm llvm-simdloop \ +CODEGEN_SRCS := codegen jitlayers aotcompile debuginfo disasm llvm-simdloop llvm-muladd \ llvm-final-gc-lowering llvm-pass-helpers llvm-late-gc-lowering llvm-ptls \ llvm-lower-handlers llvm-gc-invariant-verifier llvm-propagate-addrspaces \ llvm-multiversioning llvm-alloc-opt llvm-alloc-helpers cgmemmgr llvm-remove-addrspaces \ diff --git a/src/llvm-muladd.cpp b/src/llvm-muladd.cpp new file mode 100644 index 0000000000000..12f1c8ad765d9 --- /dev/null +++ b/src/llvm-muladd.cpp @@ -0,0 +1,117 @@ +// This file is a part of Julia. License is MIT: https://julialang.org/license + +#include "llvm-version.h" +#include "passes.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "julia.h" +#include "julia_assert.h" + +#define DEBUG_TYPE "combine-muladd" +#undef DEBUG + +using namespace llvm; +STATISTIC(TotalContracted, "Total number of multiplies marked for FMA"); + +#ifndef __clang_gcanalyzer__ +#define REMARK(remark) ORE.emit(remark) +#else +#define REMARK(remark) (void) 0; +#endif + +/** + * Combine + * ``` + * %v0 = fmul ... %a, %b + * %v = fadd contract ... %v0, %c + * ``` + * to + * `%v = call contract @llvm.fmuladd.<...>(... %a, ... %b, ... %c)` + * when `%v0` has no other use + */ + +// Return true if we changed the mulOp +static bool checkCombine(Value *maybeMul, OptimizationRemarkEmitter &ORE) JL_NOTSAFEPOINT +{ + auto mulOp = dyn_cast(maybeMul); + if (!mulOp || mulOp->getOpcode() != Instruction::FMul) + return false; + if (!mulOp->hasOneUse()) { + LLVM_DEBUG(dbgs() << "mulOp has multiple uses: " << *maybeMul << "\n"); + REMARK([&](){ + return OptimizationRemarkMissed(DEBUG_TYPE, "Multiuse FMul", mulOp) + << "fmul had multiple uses " << ore::NV("fmul", mulOp); + }); + return false; + } + // On 5.0+ we only need to mark the mulOp as contract and the backend will do the work for us. + auto fmf = mulOp->getFastMathFlags(); + if (!fmf.allowContract()) { + LLVM_DEBUG(dbgs() << "Marking mulOp for FMA: " << *maybeMul << "\n"); + REMARK([&](){ + return OptimizationRemark(DEBUG_TYPE, "Marked for FMA", mulOp) + << "marked for fma " << ore::NV("fmul", mulOp); + }); + ++TotalContracted; + fmf.setAllowContract(true); + mulOp->copyFastMathFlags(fmf); + return true; + } + return false; +} + +static bool combineMulAdd(Function &F) JL_NOTSAFEPOINT +{ + OptimizationRemarkEmitter ORE(&F); + bool modified = false; + for (auto &BB: F) { + for (auto it = BB.begin(); it != BB.end();) { + auto &I = *it; + it++; + switch (I.getOpcode()) { + case Instruction::FAdd: { + if (!I.hasAllowContract()) + continue; + modified |= checkCombine(I.getOperand(0), ORE) || checkCombine(I.getOperand(1), ORE); + break; + } + case Instruction::FSub: { + if (!I.hasAllowContract()) + continue; + modified |= checkCombine(I.getOperand(0), ORE) || checkCombine(I.getOperand(1), ORE); + break; + } + default: + break; + } + } + } +#ifdef JL_VERIFY_PASSES + assert(!verifyLLVMIR(F)); +#endif + return modified; +} + +PreservedAnalyses CombineMulAddPass::run(Function &F, FunctionAnalysisManager &AM) JL_NOTSAFEPOINT +{ + if (combineMulAdd(F)) { + return PreservedAnalyses::allInSet(); + } + return PreservedAnalyses::all(); +} diff --git a/src/llvm-simdloop.cpp b/src/llvm-simdloop.cpp index 66571f1383a22..f29802b438e1e 100644 --- a/src/llvm-simdloop.cpp +++ b/src/llvm-simdloop.cpp @@ -41,7 +41,6 @@ STATISTIC(ReductionChainLength, "Total sum of instructions folded from reduction STATISTIC(MaxChainLength, "Max length of reduction chain"); STATISTIC(AddChains, "Addition reduction chains"); STATISTIC(MulChains, "Multiply reduction chains"); -STATISTIC(TotalContracted, "Total number of multiplies marked for FMA"); #ifndef __clang_gcanalyzer__ #define REMARK(remark) ORE.emit(remark) @@ -50,49 +49,6 @@ STATISTIC(TotalContracted, "Total number of multiplies marked for FMA"); #endif namespace { -/** - * Combine - * ``` - * %v0 = fmul ... %a, %b - * %v = fadd contract ... %v0, %c - * ``` - * to - * %v0 = fmul contract ... %a, %b - * %v = fadd contract ... %v0, %c - * when `%v0` has no other use - */ - -static bool checkCombine(Value *maybeMul, Loop &L, OptimizationRemarkEmitter &ORE) JL_NOTSAFEPOINT -{ - auto mulOp = dyn_cast(maybeMul); - if (!mulOp || mulOp->getOpcode() != Instruction::FMul) - return false; - if (!L.contains(mulOp)) - return false; - if (!mulOp->hasOneUse()) { - LLVM_DEBUG(dbgs() << "mulOp has multiple uses: " << *maybeMul << "\n"); - REMARK([&](){ - return OptimizationRemarkMissed(DEBUG_TYPE, "Multiuse FMul", mulOp) - << "fmul had multiple uses " << ore::NV("fmul", mulOp); - }); - return false; - } - // On 5.0+ we only need to mark the mulOp as contract and the backend will do the work for us. - auto fmf = mulOp->getFastMathFlags(); - if (!fmf.allowContract()) { - LLVM_DEBUG(dbgs() << "Marking mulOp for FMA: " << *maybeMul << "\n"); - REMARK([&](){ - return OptimizationRemark(DEBUG_TYPE, "Marked for FMA", mulOp) - << "marked for fma " << ore::NV("fmul", mulOp); - }); - ++TotalContracted; - fmf.setAllowContract(true); - mulOp->copyFastMathFlags(fmf); - return true; - } - return false; -} - static unsigned getReduceOpcode(Instruction *J, Instruction *operand) JL_NOTSAFEPOINT { switch (J->getOpcode()) { @@ -194,28 +150,6 @@ static void enableUnsafeAlgebraIfReduction(PHINode *Phi, Loop &L, OptimizationRe }); (*K)->setHasAllowReassoc(true); (*K)->setHasAllowContract(true); - switch ((*K)->getOpcode()) { - case Instruction::FAdd: { - if (!(*K)->hasAllowContract()) - continue; - // (*K)->getOperand(0)->print(dbgs()); - // (*K)->getOperand(1)->print(dbgs()); - checkCombine((*K)->getOperand(0), L, ORE); - checkCombine((*K)->getOperand(1), L, ORE); - break; - } - case Instruction::FSub: { - if (!(*K)->hasAllowContract()) - continue; - // (*K)->getOperand(0)->print(dbgs()); - // (*K)->getOperand(1)->print(dbgs()); - checkCombine((*K)->getOperand(0), L, ORE); - checkCombine((*K)->getOperand(1), L, ORE); - break; - } - default: - break; - } if (SE) SE->forgetValue(*K); ++length; diff --git a/src/passes.h b/src/passes.h index 1bb04816af641..6557a5813063d 100644 --- a/src/passes.h +++ b/src/passes.h @@ -15,16 +15,13 @@ struct DemoteFloat16Pass : PassInfoMixin { static bool isRequired() { return true; } }; -struct LateLowerGCPass : PassInfoMixin { +struct CombineMulAddPass : PassInfoMixin { PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) JL_NOTSAFEPOINT; - static bool isRequired() { return true; } }; -struct CombineMulAddPass : PassInfoMixin { - PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) JL_NOTSAFEPOINT { - // no-op - return PreservedAnalyses::all(); - } +struct LateLowerGCPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) JL_NOTSAFEPOINT; + static bool isRequired() { return true; } }; struct AllocOptPass : PassInfoMixin { diff --git a/src/pipeline.cpp b/src/pipeline.cpp index 54b36daad9a0e..5c12e3dad0dd7 100644 --- a/src/pipeline.cpp +++ b/src/pipeline.cpp @@ -577,6 +577,7 @@ static void buildCleanupPipeline(ModulePassManager &MPM, PassBuilder *PB, Optimi if (options.cleanup) { if (O.getSpeedupLevel() >= 2) { FunctionPassManager FPM; + JULIA_PASS(FPM.addPass(CombineMulAddPass())); FPM.addPass(DivRemPairsPass()); MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM))); } diff --git a/test/llvmpasses/julia-simdloop.ll b/test/llvmpasses/julia-simdloop.ll index a6e0ac03439fa..df96e34979a3d 100644 --- a/test/llvmpasses/julia-simdloop.ll +++ b/test/llvmpasses/julia-simdloop.ll @@ -61,26 +61,6 @@ loopdone: ret double %nextv } -; CHECK-LABEL: @simd_test_sub4( -define double @simd_test_sub4(double *%a) { -top: - br label %loop -loop: - %i = phi i64 [0, %top], [%nexti, %loop] - %v = phi double [0.000000e+00, %top], [%nextv, %loop] - %aptr = getelementptr double, double *%a, i64 %i - %aval = load double, double *%aptr - %nextv2 = fmul double %aval, %aval - ; CHECK: fmul contract double %aval, %aval - %nextv = fsub double %v, %nextv2 -; CHECK: fsub reassoc contract double %v, %nextv2 - %nexti = add i64 %i, 1 - %done = icmp sgt i64 %nexti, 500 - br i1 %done, label %loopdone, label %loop, !llvm.loop !0 -loopdone: - ret double %nextv -} - ; Tests if we correctly pass through other metadata ; CHECK-LABEL: @disabled( define i32 @disabled(i32* noalias nocapture %a, i32* noalias nocapture readonly %b, i32 %N) { @@ -104,7 +84,6 @@ for.end: ; preds = %for.body ret i32 %1 } - !0 = distinct !{!0, !"julia.simdloop"} !1 = distinct !{!1, !"julia.simdloop", !"julia.ivdep"} !2 = distinct !{!2, !"julia.simdloop", !"julia.ivdep", !3} diff --git a/test/llvmpasses/muladd.ll b/test/llvmpasses/muladd.ll new file mode 100644 index 0000000000000..3c1c995ce7376 --- /dev/null +++ b/test/llvmpasses/muladd.ll @@ -0,0 +1,64 @@ +; This file is a part of Julia. License is MIT: https://julialang.org/license + +; RUN: opt -enable-new-pm=1 --opaque-pointers=0 --load-pass-plugin=libjulia-codegen%shlibext -passes='CombineMulAdd' -S %s | FileCheck %s + +; RUN: opt -enable-new-pm=1 --opaque-pointers=1 --load-pass-plugin=libjulia-codegen%shlibext -passes='CombineMulAdd' -S %s | FileCheck %s + + +; CHECK-LABEL: @fast_muladd1 +define double @fast_muladd1(double %a, double %b, double %c) { +top: +; CHECK: {{contract|fmuladd}} + %v1 = fmul double %a, %b + %v2 = fadd fast double %v1, %c +; CHECK: ret double + ret double %v2 +} + +; CHECK-LABEL: @fast_mulsub1 +define double @fast_mulsub1(double %a, double %b, double %c) { +top: +; CHECK: {{contract|fmuladd}} + %v1 = fmul double %a, %b + %v2 = fsub fast double %v1, %c +; CHECK: ret double + ret double %v2 +} + +; CHECK-LABEL: @fast_mulsub_vec1 +define <2 x double> @fast_mulsub_vec1(<2 x double> %a, <2 x double> %b, <2 x double> %c) { +top: +; CHECK: {{contract|fmuladd}} + %v1 = fmul <2 x double> %a, %b + %v2 = fsub fast <2 x double> %c, %v1 +; CHECK: ret <2 x double> + ret <2 x double> %v2 +} + +; COM: Should not mark fmul as contract when multiple uses of fmul exist +; CHECK-LABEL: @slow_muladd1 +define double @slow_muladd1(double %a, double %b, double %c) { +top: +; CHECK: %v1 = fmul double %a, %b + %v1 = fmul double %a, %b +; CHECK: %v2 = fadd fast double %v1, %c + %v2 = fadd fast double %v1, %c +; CHECK: %v3 = fadd fast double %v1, %b + %v3 = fadd fast double %v1, %b +; CHECK: %v4 = fadd fast double %v3, %v2 + %v4 = fadd fast double %v3, %v2 +; CHECK: ret double %v4 + ret double %v4 +} + +; COM: Should not mark fadd->fadd fast as contract +; CHECK-LABEL: @slow_addadd1 +define double @slow_addadd1(double %a, double %b, double %c) { +top: +; CHECK: %v1 = fadd double %a, %b + %v1 = fadd double %a, %b +; CHECK: %v2 = fadd fast double %v1, %c + %v2 = fadd fast double %v1, %c +; CHECK: ret double %v2 + ret double %v2 +} diff --git a/test/llvmpasses/parsing.ll b/test/llvmpasses/parsing.ll index b8aec5ee2fa71..e0a726176b225 100644 --- a/test/llvmpasses/parsing.ll +++ b/test/llvmpasses/parsing.ll @@ -1,6 +1,6 @@ ; COM: NewPM-only test, tests for ability to parse Julia passes -; RUN: opt --load-pass-plugin=libjulia-codegen%shlibext -passes='module(CPUFeatures,RemoveNI,JuliaMultiVersioning,RemoveJuliaAddrspaces,LowerPTLSPass,function(DemoteFloat16,LateLowerGCFrame,FinalLowerGC,AllocOpt,PropagateJuliaAddrspaces,LowerExcHandlers,GCInvariantVerifier,loop(LowerSIMDLoop,JuliaLICM),GCInvariantVerifier,GCInvariantVerifier),LowerPTLSPass,LowerPTLSPass,JuliaMultiVersioning,JuliaMultiVersioning)' -S %s -o /dev/null +; RUN: opt --load-pass-plugin=libjulia-codegen%shlibext -passes='module(CPUFeatures,RemoveNI,JuliaMultiVersioning,RemoveJuliaAddrspaces,LowerPTLSPass,function(DemoteFloat16,CombineMulAdd,LateLowerGCFrame,FinalLowerGC,AllocOpt,PropagateJuliaAddrspaces,LowerExcHandlers,GCInvariantVerifier,loop(LowerSIMDLoop,JuliaLICM),GCInvariantVerifier,GCInvariantVerifier),LowerPTLSPass,LowerPTLSPass,JuliaMultiVersioning,JuliaMultiVersioning)' -S %s -o /dev/null ; RUN: opt --load-pass-plugin=libjulia-codegen%shlibext -passes="julia" -S %s -o /dev/null ; RUN: opt --load-pass-plugin=libjulia-codegen%shlibext -passes="julia" -S %s -o /dev/null ; RUN: opt --load-pass-plugin=libjulia-codegen%shlibext -passes="julia" -S %s -o /dev/null