[AMD] Flush denorms to zero in math.rsqrt f32 (#5438)
This commit modifies the denorm handling of math.rsqrt for f32 inputs.

When ftz is enabled, it calls llvm.amdgcn.rsq.f32 directly, which flushes
denormalized inputs to zero. Otherwise, it calls __ocml_rsqrt_f32, which
dynamically checks the backend's denorm mode to decide whether to flush.

| Arch | non-ftz | ftz |
| ---- | ---- | ---- |
| CUDA | __nv_rsqrtf => rsqrt.approx.f32 | __nv_rsqrtf => rsqrt.approx.ftz.f32 |
| AMD  | __ocml_rsqrt_f32 | llvm.amdgcn.rsq.f32 |
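
To make the effect concrete, here is a minimal host-side C++ sketch (an illustration, not part of this commit): with a denormal input preserved, rsqrt stays finite; with the input flushed to zero, the result becomes +inf.

```cpp
#include <cmath>
#include <cstdio>

int main() {
  float denorm = 1e-40f;                  // subnormal in f32 (min normal ~1.18e-38)
  float kept = 1.0f / std::sqrt(denorm);  // ~1e20, still representable in f32
  float flushed = 1.0f / std::sqrt(0.0f); // +inf once the input is flushed to zero
  std::printf("non-ftz: %g\nftz:     %g\n", kept, flushed);
  return 0;
}
```

The sketch assumes the host preserves denormals by default, which is the usual x86 configuration.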
knwng authored Jan 3, 2025
1 parent 37817d7 commit 781ae0b
Showing 4 changed files with 57 additions and 3 deletions.
2 changes: 1 addition & 1 deletion lib/Dialect/TritonNvidiaGPU/Transforms/PlanCTA.cpp
@@ -656,7 +656,7 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const {
math::CtPopOp, math::ErfOp, math::ExpOp, math::Exp2Op,
math::FloorOp, math::ExpM1Op, math::FmaOp, math::LogOp,
math::Log10Op, math::Log1pOp, math::Log2Op, math::PowFOp,
- math::RsqrtOp, math::SqrtOp, math::RsqrtOp, math::TanhOp>(op))
+ math::SqrtOp, math::RsqrtOp, math::TanhOp>(op))
return true;
if (llvm::isa<triton::IntToPtrOp, triton::PtrToIntOp, triton::BitcastOp,
triton::FpToFpOp, triton::AddPtrOp, triton::PreciseSqrtOp,
12 changes: 12 additions & 0 deletions test/Conversion/amd/math-denorm-handling.mlir
@@ -23,3 +23,15 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
tt.return
}
}

// -----

#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
tt.func public @test_rsqrt(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} {
// LLVM_FTZ: llvm.amdgcn.rsq.f32
// LLVM_NO_FTZ: _ocml_rsqrt_f32
%0 = math.rsqrt %arg0 : tensor<64xf32, #blocked>
tt.return
}
}
42 changes: 42 additions & 0 deletions third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -1330,6 +1330,46 @@ struct Exp2OpConversion
bool ftz;
};

struct RsqrtOpConversion
: ElementwiseOpConversionBase<mlir::math::RsqrtOp, RsqrtOpConversion> {
using ElementwiseOpConversionBase<
mlir::math::RsqrtOp, RsqrtOpConversion>::ElementwiseOpConversionBase;

explicit RsqrtOpConversion(LLVMTypeConverter &typeConverter,
ModuleAxisInfoAnalysis &axisInfoAnalysis, bool ftz,
PatternBenefit benefit)
: ElementwiseOpConversionBase(typeConverter, axisInfoAnalysis, benefit),
ftz(ftz) {}

SmallVector<Value> createDestOps(mlir::math::RsqrtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Type elemTy, MultipleOperandsRange operands,
Location loc) const {
// This pattern only handles FP32 input with the ftz configuration. Other
// cases are delegated to MLIR.
//
// For FP16/FP64 input, it's lowered to __ocml_rsqrt_f16/__ocml_rsqrt_f64.
//
// For FP32 input with non-ftz configuration, it's lowered to
// __ocml_rsqrt_f32, which dynamically checks the backend's ftz/daz settings
// to decide whether to preserve or flush denorms.
if (elemTy.getIntOrFloatBitWidth() != 32 || !ftz)
return {};

// `llvm.amdgcn.rsq.f32` provides direct access to v_rsq_f32_e32.
StringRef funcName = "llvm.amdgcn.rsq.f32";
Type funcType = getFunctionType(elemTy, operands[0]);
LLVM::LLVMFuncOp funcOp =
appendOrGetExternFuncOp(rewriter, op, funcName, funcType);

return {
LLVM::createLLVMCallOp(rewriter, loc, funcOp, operands[0]).getResult()};
}

private:
bool ftz;
};

} // namespace

namespace mlir::triton::AMD {
@@ -1371,6 +1411,8 @@ void populateElementwiseOpToLLVMPatterns(
// Exp2OpConversion will return failure and later pass will call
// __ocml_exp2_f64 for higher-precision calculation
patterns.add<Exp2OpConversion>(typeConverter, axisInfoAnalysis, ftz, benefit);
patterns.add<RsqrtOpConversion>(typeConverter, axisInfoAnalysis, ftz,
benefit);
mlir::triton::populateElementwiseOpToLLVMPatterns(
typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit);
mlir::triton::populateMinMaxFOpToLLVMPattern(
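
Both Exp2OpConversion and the new RsqrtOpConversion rely on the same fallback mechanism: when createDestOps returns no values, the specialized pattern fails to apply and a later, more generic lowering emits the __ocml_* library call instead. The standalone C++ sketch below (a loose analogy with made-up names, not MLIR's actual pattern driver) illustrates that first-match-wins dispatch:

```cpp
#include <cstdio>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Hypothetical illustration, not MLIR's real API: each "pattern" may decline
// an input by returning std::nullopt, in which case the next registered
// pattern is tried -- mirroring how RsqrtOpConversion returns {} for cases it
// does not handle, leaving them to the generic OCML-based lowering.
using Pattern =
    std::function<std::optional<std::string>(unsigned bitWidth, bool ftz)>;

int main() {
  std::vector<Pattern> patterns = {
      // Specialized pattern: only FP32 with ftz enabled.
      [](unsigned bw, bool ftz) -> std::optional<std::string> {
        if (bw != 32 || !ftz)
          return std::nullopt;          // decline; fall through
        return "llvm.amdgcn.rsq.f32";   // direct hardware intrinsic
      },
      // Generic fallback: OCML library call chosen by bit width.
      [](unsigned bw, bool) -> std::optional<std::string> {
        return "__ocml_rsqrt_f" + std::to_string(bw);
      },
  };

  for (unsigned bw : {16u, 32u, 64u})
    for (bool ftz : {false, true})
      for (auto &p : patterns)
        if (auto callee = p(bw, ftz)) {
          std::printf("f%-2u ftz=%d -> %s\n", bw, ftz, callee->c_str());
          break;
        }
  return 0;
}
```

Running it should print the __ocml_rsqrt_f16/f64 calls for the non-f32 widths and llvm.amdgcn.rsq.f32 only for f32 with ftz enabled, mirroring the table in the commit message.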
@@ -47,8 +47,8 @@ bool isOneOperandElementwiseOp(Operation *op) {
math::CountLeadingZerosOp, math::CountTrailingZerosOp,
math::CtPopOp, math::ErfOp, math::ExpOp, math::Exp2Op,
math::ExpM1Op, math::FloorOp, math::LogOp, math::Log10Op,
- math::Log1pOp, math::Log2Op, math::RsqrtOp, math::SqrtOp,
- math::RsqrtOp, math::TanhOp>(op))
+ math::Log1pOp, math::Log2Op, math::SqrtOp, math::RsqrtOp,
+ math::TanhOp>(op))
return true;
if (llvm::isa<triton::IntToPtrOp, triton::PtrToIntOp, triton::BitcastOp,
triton::FpToFpOp>(op))
