Skip to content

Commit

Permalink
[BACKEND] Use an optimized ptx code sequence for fp4 upcasting (trito…
Browse files Browse the repository at this point in the history
…n-lang#5344)

Integrate code sequence made by @rawnhenry for efficient fp4 upcasting

Co-authored-by: Rawn Henry <rawnhenry@gmail.com>
  • Loading branch information
ThomasRaoux and rawnhenry authored Dec 5, 2024
1 parent 147d332 commit 10552c5
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 1 deletion.
17 changes: 17 additions & 0 deletions test/Conversion/tritongpu_to_llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2028,3 +2028,20 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shar
}

}

// -----

#linear = #ttg.linear<{register = [], lane = [[0, 1], [1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [16, 0]], block = []}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {

tt.func @upcast_mxfp(%arg0: tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, %arg1: tensor<32x2xi8, #linear>) {
// CHECK-LABEL: upcast_mxfp
// CHECK-COUNT-4: llvm.inline_asm
// CHECK-COUNT-2: nvvm.shfl.sync
// CHECK-COUNT-32: llvm.fmul
%0 = ttg.upcast_mxfp %arg0, %arg1 fp_type = e2m1 : tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<32x2xi8, #linear> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
tt.return
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "PatternTritonGPUOpToLLVM.h"

#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
Expand All @@ -19,6 +20,73 @@ using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::gpu;

// Convert 8 fp4 elements packed into a 32bit reg into 8 bf16 elements packed
// into 4 32bits regs.
static constexpr const char *ptxAsm =
"{\n"
".reg .b32 a<14>;\n"
"and.b32 a0, $4, -2004318072;\n\t"
"shr.u32 a1, a0, 3;\n\t"
"and.b32 a2, $4, 2004318071;\n\t"
"shr.u32 a3, a2, 16;\n\t"
"shr.u32 a4, a0, 19;\n\t"
"prmt.b32 a5, -1065353216, -1065336832, a2;\n\t"
"prmt.b32 a6, -1065353216, -1065336832, a3;\n\t"
"prmt.b32 a7, 1061109504, 1077952576, a2;\n\t"
"prmt.b32 a8, 1061109504, 1077952576, a3;\n\t"
"prmt.b32 a9, 32768, 0, a1;\n\t"
"prmt.b32 a10, 32768, 0, a4;\n\t"
"or.b32 a11, a7, a9;\n\t"
"or.b32 a12, a8, a10;\n\t"
"prmt.b32 $0, a5, a11, 20800;\n\t"
"prmt.b32 $1, a5, a11, 29538;\n\t"
"prmt.b32 $2, a6, a12, 20800;\n\t"
"prmt.b32 $3, a6, a12, 29538;\n\t"
"}";

static Value createInlineAsmUpcast(Location loc, RewriterBase &rewriter,
Type retType, Value packedVec) {
PTXBuilder builder;
SmallVector<PTXBuilder::Operand *> operands;
for (int i = 0; i < 4; i++) {
operands.push_back(builder.newOperand("=r"));
}
operands.push_back(builder.newOperand(packedVec, "r"));
auto &ptxOp = *builder.create(ptxAsm);
ptxOp(operands, /*onlyAttachMLIRArgs=*/true);
Value result = builder.launch(rewriter, loc, retType, false);
return result;
}

static SmallVector<Value> convertMxfp4x2ToBf16x2PTX(RewriterBase &rewriter,
Location loc,
ArrayRef<Value> values) {
SmallVector<Value> results;
MLIRContext *ctx = rewriter.getContext();
assert(values.size() % 4 == 0);
for (int i = 0; i < values.size(); i += 4) {
Value v0 = values[i];
Value v1 = values[i + 1];
Value v2 = values[i + 2];
Value v3 = values[i + 3];
Value packedVec = undef(vec_ty(i8_ty, 4));
packedVec = insert_element(packedVec, v0, i32_val(0));
packedVec = insert_element(packedVec, v1, i32_val(1));
packedVec = insert_element(packedVec, v2, i32_val(2));
packedVec = insert_element(packedVec, v3, i32_val(3));
SmallVector<Type> rets(4, i32_ty);
Type retType = struct_ty(rets);
Value ret = createInlineAsmUpcast(loc, rewriter, retType, packedVec);
for (int i = 0; i < 4; i++) {
Value extractI32 = extract_val(ret, i);
Value vecbf16 = bitcast(extractI32, vec_ty(bf16_ty, 2));
results.push_back(extract_element(vecbf16, i32_val(0)));
results.push_back(extract_element(vecbf16, i32_val(1)));
}
}
return results;
}

namespace {
class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
private:
Expand Down Expand Up @@ -53,7 +121,7 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern<UpcastMXFPOp> {
cast<DotOperandEncodingAttr>(op.getType().getEncoding()).getKWidth();

if (fpType == ScaleDotElemType::E2M1)
xVals = LLVM::convertMxfp4x2ToBf16x2(rewriter, loc, xVals);
xVals = convertMxfp4x2ToBf16x2PTX(rewriter, loc, xVals);

// Each thread owns elements of 4 mxfp vectors so we need 4 scales
// Since we go from a threadShape of 8x4 to 16x2, we let c = tid / 4 * 2
Expand Down

0 comments on commit 10552c5

Please sign in to comment.