From 781ae0b2f6b26c496054624d807fc64c0ed6594c Mon Sep 17 00:00:00 2001
From: Kyle Wang
Date: Fri, 3 Jan 2025 13:26:28 +0800
Subject: [PATCH] [AMD] Flush denorms to zero in math.rsqrt f32 (#5438)

This commit modifies the denorm handling behavior of math.rsqrt. In case
of ftz, it calls llvm.amdgcn.rsq.f32 directly to flush the denormalized
inputs to zero. Otherwise, it calls __ocml_rsqrt_f32, which dynamically
checks the backend to decide whether to flush denorms or not.

| Arch | non-ftz | ftz |
| ---- | ---- | ---- |
| CUDA | __nv_rsqrtf => rsqrt.approx.f32 | __nv_rsqrtf => rsqrt.approx.ftz.f32 |
| AMD | __ocml_rsqrt_f32 | llvm.amdgcn.rsq.f32 |
---
 .../TritonNvidiaGPU/Transforms/PlanCTA.cpp    |  2 +-
 test/Conversion/amd/math-denorm-handling.mlir | 12 ++++++
 .../ElementwiseOpToLLVM.cpp                   | 42 +++++++++++++++++++
 .../OptimizeEpilogue.cpp                      |  4 +-
 4 files changed, 57 insertions(+), 3 deletions(-)

diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp
index e26af25a6d05..7769520723b9 100644
--- a/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp
+++ b/lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp
@@ -656,7 +656,7 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const {
                 math::CtPopOp, math::ErfOp, math::ExpOp, math::Exp2Op,
                 math::FloorOp, math::ExpM1Op, math::FmaOp, math::LogOp,
                 math::Log10Op, math::Log1pOp, math::Log2Op, math::PowFOp,
-                math::RsqrtOp, math::SqrtOp, math::RsqrtOp, math::TanhOp>(op))
+                math::SqrtOp, math::RsqrtOp, math::TanhOp>(op))
     return true;
   if (llvm::isa
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @test_rsqrt(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} {
+    // LLVM_FTZ: llvm.amdgcn.rsq.f32
+    // LLVM_NO_FTZ: _ocml_rsqrt_f32
+    %0 = math.rsqrt %arg0 : tensor<64xf32, #blocked>
+    tt.return
+  }
+}
diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
index 716a93865ddd..df1e1b7e8859 100644
--- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -1330,6 +1330,46 @@ struct Exp2OpConversion
   bool ftz;
 };
 
+struct RsqrtOpConversion
+    : ElementwiseOpConversionBase<mlir::math::RsqrtOp, RsqrtOpConversion> {
+  using ElementwiseOpConversionBase<
+      mlir::math::RsqrtOp, RsqrtOpConversion>::ElementwiseOpConversionBase;
+
+  explicit RsqrtOpConversion(LLVMTypeConverter &typeConverter,
+                             ModuleAxisInfoAnalysis &axisInfoAnalysis, bool ftz,
+                             PatternBenefit benefit)
+      : ElementwiseOpConversionBase(typeConverter, axisInfoAnalysis, benefit),
+        ftz(ftz) {}
+
+  SmallVector<Value> createDestOps(mlir::math::RsqrtOp op, OpAdaptor adaptor,
+                                   ConversionPatternRewriter &rewriter,
+                                   Type elemTy, MultipleOperandsRange operands,
+                                   Location loc) const {
+    // This pattern only handles FP32 input in the ftz configuration. All other
+    // cases return an empty result and are delegated to MLIR.
+    //
+    // For FP16/FP64 input, it's lowered to __ocml_rsqrt_f16/__ocml_rsqrt_f64.
+    //
+    // For FP32 input with non-ftz configuration, it's lowered to
+    // __ocml_rsqrt_f32, which will check the ftz/daz settings in the backend
+    // dynamically to decide whether to preserve or flush denorms.
+    if (elemTy.getIntOrFloatBitWidth() != 32 || !ftz)
+      return {};
+
+    // `llvm.amdgcn.rsq.f32` provides direct access to v_rsq_f32_e32.
+    StringRef funcName = "llvm.amdgcn.rsq.f32";
+    Type funcType = getFunctionType(elemTy, operands[0]);
+    LLVM::LLVMFuncOp funcOp =
+        appendOrGetExternFuncOp(rewriter, op, funcName, funcType);
+
+    return {
+        LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()};
+  }
+
+private:
+  bool ftz; // Selects the llvm.amdgcn.rsq.f32 lowering for FP32 input.
+};
+
 } // namespace
 
 namespace mlir::triton::AMD {
@@ -1371,6 +1411,8 @@ void populateElementwiseOpToLLVMPatterns(
   // Exp2OpConversion will return failure and later pass will call
   // __ocml_exp2_f64 for higher-precision calculation
   patterns.add<Exp2OpConversion>(typeConverter, axisInfoAnalysis, ftz, benefit);
+  patterns.add<RsqrtOpConversion>(typeConverter, axisInfoAnalysis, ftz,
+                                  benefit);
   mlir::triton::populateElementwiseOpToLLVMPatterns(
       typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit);
   mlir::triton::populateMinMaxFOpToLLVMPattern(
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp
index f2818297f5d0..df990b5af20e 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp
@@ -47,8 +47,8 @@ bool isOneOperandElementwiseOp(Operation *op) {
                 math::CountLeadingZerosOp, math::CountTrailingZerosOp,
                 math::CtPopOp, math::ErfOp, math::ExpOp, math::Exp2Op,
                 math::ExpM1Op, math::FloorOp, math::LogOp, math::Log10Op,
-                math::Log1pOp, math::Log2Op, math::RsqrtOp, math::SqrtOp,
-                math::RsqrtOp, math::TanhOp>(op))
+                math::Log1pOp, math::Log2Op, math::SqrtOp, math::RsqrtOp,
+                math::TanhOp>(op))
     return true;
   if (llvm::isa(op))