Skip to content

Commit

Permalink
Enable rsqrt and floor for BF16. (triton-lang#109)
Browse files Browse the repository at this point in the history
Signed-off-by: Ilya Enkovich <ilya.enkovich@intel.com>
  • Loading branch information
ienkovich authored and minjang committed Oct 23, 2024
1 parent d4043d5 commit e64fc27
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 5 deletions.
4 changes: 0 additions & 4 deletions python/test/unit/cpu/test_libdevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ def test_libdevice(dtype_str, math_fn, size, device):
if not is_cpu():
pytest.skip("This test is CPU-specific")

if dtype_str == "bfloat16":
if math_fn == "floor" or math_fn == "rsqrt":
pytest.skip("libgcc < 13 does not define __truncsfbf2, which this op needs")

@triton.jit
def kernel(src, dst, MATH_FN: tl.constexpr, BLOCK_SIZE: tl.constexpr):
idxs = tl.arange(0, BLOCK_SIZE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ struct ConvertUnsupportedOps
patterns.add<PromoteOpToFp32<math::Log10Op>>(context);
patterns.add<PromoteOpToFp32<math::Log1pOp>>(context);
patterns.add<PromoteOpToFp32<math::PowFOp>>(context);
patterns.add<PromoteOpToFp32<math::RsqrtOp>>(context);
patterns.add<PromoteOpToFp32<math::SinOp>>(context);
patterns.add<PromoteOpToFp32<math::SinhOp>>(context);
patterns.add<PromoteOpToFp32<math::SqrtOp>>(context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,25 @@ struct Fp32ToBf16Conversion : public OpRewritePattern<arith::TruncFOp> {
}
};

// Rewrite pattern for arith::ExtFOp that handles only the BF16 -> FP32 case.
// The extension is expressed with integer ops: the source is bitcast to a
// 16-bit integer type, zero-extended to 32 bits, shifted left by 16, and the
// result is bitcast to the op's original result type.
struct Bf16ToFp32Conversion : public OpRewritePattern<arith::ExtFOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::ExtFOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.getIn();
    auto inputTy = input.getType();
    auto resultTy = op.getType();

    // Only a BF16 source extended to an FP32 result is handled here; every
    // other extension is left for other patterns.
    if (!(isFp32(resultTy) && isBf16(inputTy)))
      return failure();

    // NOTE(review): `loc` is not referenced explicitly below; the op_* and
    // cst_like helpers presumably pick up `loc` and `rewriter` by name, so
    // both names must stay as they are -- confirm against their definitions.
    Location loc = op.getLoc();

    // Reinterpret the source bits as a 16-bit integer.
    Value bits16 = op_bitcast(toInt16(inputTy), input);
    // Zero-extend to 32 bits, then shift the 16 source bits into the upper
    // half of the 32-bit value.
    Value bits32 = op_zext(toInt32(inputTy), bits16);
    Value shifted = op_shl(bits32, cst_like(bits32, 16));

    // Reinterpret the shifted integer as the original result type.
    rewriter.replaceOp(op, op_bitcast(resultTy, shifted));
    return success();
  }
};

typedef std::function<Value(Location, Value, PatternRewriter &)> FpToFpConvFn;

// Convert FP8 to FP16/FP32.
Expand Down Expand Up @@ -501,8 +520,10 @@ struct DecomposeFpConversions
ModuleOp mod = getOperation();

RewritePatternSet patterns(context);
if (decomposeBf16Conversions)
if (decomposeBf16Conversions) {
patterns.add<Fp32ToBf16Conversion>(context);
patterns.add<Bf16ToFp32Conversion>(context);
}
if (decomposeFp8Conversions) {
patterns.add<RewriteTruncFp8>(context);
patterns.add<RewriteExtFp8>(context);
Expand Down

0 comments on commit e64fc27

Please sign in to comment.