Skip to content

Commit 4bf2524

Browse files
committed
Merge remote-tracking branch 'upstream/main' into tvm_rebase
2 parents 9ef647c + f7ba45d commit 4bf2524

File tree

12 files changed

+1297
-297
lines changed

12 files changed

+1297
-297
lines changed
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import tilelang.language as T
2+
from typing import Literal, Callable
3+
from tvm.tir import IndexMap
4+
from tilelang.intrinsics.utils import get_mma_micro_size
5+
6+
from tilelang.intrinsics.mfma_layout import (
7+
shared_16x4_to_local_64x1_layout_A,
8+
shared_16x16_to_local_64x4_layout_A,
9+
shared_16x32_to_local_64x8_layout_A,
10+
shared_16x64_to_local_64x16_layout_A,
11+
)
12+
13+
14+
def make_mfma_load_base_layout(dtype: str = "float16",
                               matrix: Literal["A", "B"] = "A",
                               k_dim: int = 16,
                               transposed: bool = False) -> T.Fragment:
    """
    Create the base fragment layout for *loading* an MFMA operand.

    The returned fragment maps each (row, col) index of the operand tile to
    a (lane_id, local_id) pair via the inverse shared->local MFMA layout, so
    it describes how threads and per-thread registers are laid out for the
    load.

    Parameters
    ----------
    dtype : str
        The data type of the matrix.
    matrix : Literal["A", "B"]
        The mfma operand to be loaded.
    k_dim : int
        The k dimension of the mfma (4, 16, 32 or 64).
    transposed : bool
        Whether the matrix is transposed, by default False.

    Returns
    -------
    T.Fragment
        Describes how threads and indices in the fragment are laid out.

    Raises
    ------
    ValueError
        If `k_dim` is unsupported or `matrix` is not "A"/"B".
    """
    assert matrix in ["A", "B"], "matrix should be either A or B"

    # Inverse shared->local layouts, keyed by k dimension.  Both operands use
    # the same "A-shaped" transform; only the axis order below differs.
    # NOTE(review): original code assigned identical functions to separate
    # A/B variables — collapsed here into a single lookup.
    k_dim_to_transform = {
        4: shared_16x4_to_local_64x1_layout_A,
        16: shared_16x16_to_local_64x4_layout_A,
        32: shared_16x32_to_local_64x8_layout_A,
        64: shared_16x64_to_local_64x16_layout_A,
    }
    if k_dim not in k_dim_to_transform:
        raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")
    transform_func_sr: Callable = k_dim_to_transform[k_dim]

    # s = spatial axis, r = reduction axis.
    # "sr" axis order means (spatial, reduction); "rs" means the reverse.
    # A is naturally (s, r) unless transposed; B is (s, r) only when transposed.
    is_sr_axis_order = (matrix == "A" and not transposed) or \
                       (matrix == "B" and transposed)

    micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype)

    # mma.sync uses a row.col layout, so the B operand expects a transposed
    # basic layout relative to A.
    if matrix == "A":
        micro_size_s, micro_size_r = micro_size_x, micro_size_k
    elif matrix == "B":
        micro_size_s, micro_size_r = micro_size_k, micro_size_y
    else:
        raise ValueError(f"Unsupported matrix {matrix}")

    if is_sr_axis_order:
        transform_func: Callable = transform_func_sr
    else:
        # Swap the index order for the (reduction, spatial) axis layout.
        transform_func = lambda i, j: transform_func_sr(j, i)

    inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")

    def forward_thread(i: int, j: int) -> int:
        """Map fragment indices (i, j) to the owning lane (thread) id."""
        lane_id, _ = inverse_mma_load_layout.map_indices([i, j])
        return lane_id

    def forward_index(i: int, j: int) -> int:
        """Map fragment indices (i, j) to the local (per-thread) index."""
        _, local_id = inverse_mma_load_layout.map_indices([i, j])
        return local_id

    base_fragment = T.Fragment(
        [micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
        forward_thread_fn=forward_thread,
        forward_index_fn=forward_index,
    )
    return base_fragment
107+
108+
109+
# Demo: build and visualize MFMA load layouts at increasing granularity.
block_rows = 2
block_cols = 2
warp_rows = 2
warp_cols = 2
chunk = 2

from tilelang.tools import plot_layout

# ldmatrix layout 16x16 — the per-warp base fragment for operand A.
base_layout = make_mfma_load_base_layout(dtype="float16", matrix="A", transposed=False)
print(base_layout)
plot_layout(base_layout, name="base_layout")

# warp layout 32x32 — tile the base fragment across the warp's register space.
warp_layout = base_layout.repeat(
    [warp_rows, warp_cols],
    repeat_on_thread=False,
    lower_dim_first=False,
)
print(warp_layout)
plot_layout(warp_layout, name="warp_layout")

# block layout 64x32 — repeat over warps in a block, then replicate columns.
block_layout = warp_layout.repeat(
    [block_rows, 1],
    repeat_on_thread=True,
    lower_dim_first=True,
).replicate(block_cols)
print(block_layout)
plot_layout(block_layout, name="block_layout")

src/target/codegen_cuda.cc

Lines changed: 72 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,6 +1018,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
10181018
<< "))+1), __NV_SATFINITE, "
10191019
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
10201020
<< ");\n";
1021+
os << sret;
1022+
return;
10211023
}
10221024
}
10231025

@@ -1035,6 +1037,52 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
10351037
os << sret;
10361038
}
10371039

1040+
void CodeGenTileLangCUDA::VisitExpr_(const MinNode *op, std::ostream &os) {
1041+
// TODO(wt): Consider vectorized reduction and impl for other dtypes
1042+
DataType t = op->dtype;
1043+
1044+
// Standard min/max functions don't support bfloat16 or float16
1045+
if ((t.is_bfloat16() || t.is_float16()) && t.is_scalar()) {
1046+
os << "cutlass::fast_min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b)
1047+
<< ")";
1048+
return;
1049+
}
1050+
1051+
// For float32 and float64 scalar, use standard min functions
1052+
if (t.is_float() && t.is_scalar()) {
1053+
if (t.bits() == 32 || t.bits() == 64) {
1054+
os << "min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
1055+
return;
1056+
}
1057+
}
1058+
1059+
// For all other scalar types (int, uint), use default implementation
1060+
CodeGenC::VisitExpr_(op, os);
1061+
}
1062+
1063+
void CodeGenTileLangCUDA::VisitExpr_(const MaxNode *op, std::ostream &os) {
1064+
// TODO(wt): Consider vectorized reduction and impl for other dtypes
1065+
DataType t = op->dtype;
1066+
1067+
// Standard min/max functions don't support bfloat16 or float16
1068+
if ((t.is_bfloat16() || t.is_float16()) && t.is_scalar()) {
1069+
os << "cutlass::fast_max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b)
1070+
<< ")";
1071+
return;
1072+
}
1073+
1074+
// For float32 and float64 scalar, use standard max functions
1075+
if (t.is_float() && t.is_scalar()) {
1076+
if (t.bits() == 32 || t.bits() == 64) {
1077+
os << "max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
1078+
return;
1079+
}
1080+
}
1081+
1082+
// For all other scalar types (int, uint), use default implementation
1083+
CodeGenC::VisitExpr_(op, os);
1084+
}
1085+
10381086
void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol,
10391087
const Array<PrimExpr> &args,
10401088
bool skip_first_arg,
@@ -2541,12 +2589,29 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
25412589

25422590
inline void PrintConst(const FloatImmNode *op, std::ostream &os,
25432591
CodeGenTileLangCUDA *p) { // NOLINT(*)
2544-
// Type code is kBFloat
2545-
if (op->dtype.is_bfloat16()) {
2546-
os << "bfloat16_t";
2547-
os << '(' << std::hexfloat << op->value << 'f';
2548-
os << "/*" << std::scientific << op->value << "*/";
2549-
os << ')';
2592+
// Type code is kBFloat/kFloat16
2593+
// which is indeed CUTLASS supported types currently
2594+
if (op->dtype.is_bfloat16() || op->dtype.is_float16()) {
2595+
std::ostringstream temp;
2596+
if (std::isinf(op->value)) {
2597+
if (op->value < 0) {
2598+
temp << "-";
2599+
}
2600+
temp << "std::numeric_limits<";
2601+
p->PrintType(op->dtype, temp);
2602+
temp << ">::infinity()";
2603+
} else if (std::isnan(op->value)) {
2604+
temp << "std::numeric_limits<";
2605+
p->PrintType(op->dtype, temp);
2606+
temp << ">::quiet_NaN()";
2607+
} else {
2608+
p->PrintType(op->dtype, temp);
2609+
temp << '(' << std::hexfloat << op->value << 'f';
2610+
temp << "/*" << std::scientific << op->value << "*/";
2611+
temp << ')';
2612+
}
2613+
p->MarkConst(temp.str());
2614+
os << temp.str();
25502615
return;
25512616
}
25522617
// Type code is kFloat8_e5m2 or kE4M4Float
@@ -2557,7 +2622,7 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
25572622
os << ')';
25582623
return;
25592624
}
2560-
// Type code is kFloat
2625+
// Type code is kFloat64/kFloat32 (kFloat16 is handled above)
25612626
switch (op->dtype.bits()) {
25622627
case 64:
25632628
case 32: {
@@ -2581,13 +2646,6 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
25812646
os << temp.str();
25822647
break;
25832648
}
2584-
case 16: {
2585-
os << "half_t" << '(';
2586-
FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
2587-
PrintConst(const_f32.get(), os, p);
2588-
os << ')';
2589-
break;
2590-
}
25912649
default:
25922650
LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
25932651
}

src/target/codegen_cuda.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ class CodeGenTileLangCUDA final : public CodeGenC {
5151
void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;
5252
void VisitExpr_(const CallNode *op, std::ostream &os) final;
5353
void VisitExpr_(const CastNode *op, std::ostream &os) final;
54+
void VisitExpr_(const MinNode *op, std::ostream &os) final;
55+
void VisitExpr_(const MaxNode *op, std::ostream &os) final;
5456
void VisitStmt_(const EvaluateNode *op) final;
5557
void VisitStmt_(const AllocateNode *op) final;
5658
void VisitStmt_(const AttrStmtNode *op) final;

0 commit comments

Comments
 (0)