26 changes: 9 additions & 17 deletions examples/attention_sink/example_gqa_sink_bwd_bhsd.py
@@ -20,11 +20,9 @@ def get_bwd_configs():
 
 
 @tilelang.jit(
-    out_idx=[3, 4],
-    pass_configs={
+    out_idx=[3, 4], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_fwd(
     batch,
     heads,
@@ -140,11 +138,9 @@ def flash_fwd(
 
 
 @tilelang.jit(
-    out_idx=[2],
-    pass_configs={
+    out_idx=[2], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -180,11 +176,9 @@ def make_dq_layout(dQ):
 
 
 @tilelang.jit(
-    out_idx=[1],
-    pass_configs={
+    out_idx=[1], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -205,11 +199,9 @@ def flash_bwd_post(
     return flash_bwd_post
 
 
-@tilelang.jit(
-    pass_configs={
-        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def flashattn_bwd(batch,
                   heads,
                   seq_len,
@@ -23,11 +23,9 @@ def get_configs():
     rep=100,
 )
 @tilelang.jit(
-    out_idx=[3],
-    pass_configs={
+    out_idx=[3], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn(
     batch,
     heads,
26 changes: 9 additions & 17 deletions examples/attention_sink/example_mha_sink_bwd_bhsd.py
@@ -20,11 +20,9 @@ def get_bwd_configs():
 
 
 @tilelang.jit(
-    out_idx=[3, 4],
-    pass_configs={
+    out_idx=[3, 4], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_fwd(
     batch,
     heads,
@@ -137,11 +135,9 @@ def flash_fwd(
 
 
 @tilelang.jit(
-    out_idx=[2],
-    pass_configs={
+    out_idx=[2], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -177,11 +173,9 @@ def make_dq_layout(dQ):
 
 
 @tilelang.jit(
-    out_idx=[1],
-    pass_configs={
+    out_idx=[1], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
     accum_dtype = "float"
     shape = [batch, heads, seq_len, dim]
@@ -202,11 +196,9 @@ def flash_bwd_post(
     return flash_bwd_post
 
 
-@tilelang.jit(
-    pass_configs={
-        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+@tilelang.jit(pass_configs={
+    tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
+})
 def flashattn_bwd(
     batch,
     heads,
6 changes: 2 additions & 4 deletions examples/attention_sink/example_mha_sink_fwd_bhsd.py
@@ -18,11 +18,9 @@ def get_configs():
 
 @autotune(configs=get_configs(), warmup=500, rep=100)
 @tilelang.jit(
-    out_idx=[3],
-    pass_configs={
+    out_idx=[3], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn(
     batch,
     heads,
@@ -19,11 +19,9 @@ def get_configs():
 
 @autotune(configs=get_configs(), warmup=500, rep=100)
 @tilelang.jit(
-    out_idx=[3],
-    pass_configs={
+    out_idx=[3], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
-    },
-    compile_flags=["-O3", "-DENABLE_BF16"])
+    })
 def flashattn(
     batch,
     heads,
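
Note that each example file above gets the same two-part edit: the `out_idx`/`pass_configs` arguments are reflowed onto one line, and `compile_flags=["-O3", "-DENABLE_BF16"]` is dropped entirely. The dropped flag pairs with the `codegen_cuda.cc` change below: the cast codegen no longer wraps bfloat16 conversions in `#ifdef ENABLE_BF16` around `fastertransformer::` helpers, so generated kernels compile without defining the macro. A minimal sketch of the kind of device code the new backend emits for a 2-lane bf16-to-fp32 cast (assuming a CUDA 11+ toolchain; the function name is illustrative, not from the PR):

#include <cuda_bf16.h>

// New-style emission: a plain CUDA intrinsic, no ENABLE_BF16 guard needed.
__device__ float2 upcast_bf16x2(__nv_bfloat162 v) {
  return __bfloat1622float2(v);
}
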
168 changes: 116 additions & 52 deletions src/target/codegen_cuda.cc
@@ -900,56 +900,123 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
   stream << ' ' << sret << ";\n";
   std::string src = SSAGetID(PrintExpr(op->value), from_ty);
 
-  // Handle bfloat16 special cases with supported ops
-  bool used_bf16_op = false;
-  if (from_ty.is_bfloat16() || target_ty.is_bfloat16()) {
-    std::ostringstream func_name;
-    if (from_ty.is_bfloat16()) {
-      func_name << "bf16";
-    } else if (from_ty.is_float()) {
-      func_name << "float";
-    }
-    if (from_ty.lanes() > 1) {
-      func_name << from_ty.lanes();
-    }
-    func_name << "2";
-    if (target_ty.is_bfloat16()) {
-      func_name << "bf16";
-    } else if (target_ty.is_float()) {
-      func_name << "float";
-    } else if (target_ty == DataType::Int(16)) {
-      func_name << "int16";
-    }
-    if (target_ty.lanes() > 1) {
-      func_name << target_ty.lanes();
-    }
-
-    auto fname = func_name.str();
-    if (bf16_supported_ops_.count(fname)) {
-      used_bf16_op = true;
-      stream << "#ifdef ENABLE_BF16\n";
+  // Handle conversion between float16 and float32
+  if (from_ty.is_float16() && target_ty.is_float()) {
+    // Use __half22float2 for vectorized conversion (half2 -> float2)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // half2 -> float2
       PrintIndent();
-      stream << "reinterpret_cast<";
-      if (target_ty.is_bfloat16()) {
-        stream << "__nv_bfloat16";
-      } else {
-        PrintType(target_ty.element_of(), stream);
-      }
-      if (target_ty.lanes() > 1) {
-        stream << target_ty.lanes();
-      }
-      stream << " &>(" << sret << ") = fastertransformer::" << fname
-             << "(reinterpret_cast<";
-      if (from_ty.is_bfloat16()) {
-        stream << "__nv_bfloat16";
-      } else {
-        PrintType(from_ty.element_of(), stream);
-      }
-      if (from_ty.lanes() > 1) {
-        stream << from_ty.lanes();
-      }
-      stream << " const &>(" << src << "));\n";
-      stream << "#else\n";
+      stream << sret << " = __half22float2(*(half2*)(&(" << src << ")));\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // half4 -> float4
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[0] = "
+             << "__half22float2(*(half2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[1] = "
+             << "__half22float2(*((half2*)(&(" << src << "))+1));\n";
+      os << sret;
+      return;
+    }
+  } else if (from_ty.is_float() && target_ty.is_float16()) {
+    // Use __float22half2_rn for vectorized conversion (float2 -> half2)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // float2 -> half2
+      PrintIndent();
+      stream << "*(half2*)(&(" << sret << ")) = __float22half2_rn(*(float2*)(&("
+             << src << ")));\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // float4 -> half4
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[0] = "
+             << "__float22half2_rn(*(float2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "((half2*)(&" << sret << "))[1] = "
+             << "__float22half2_rn(*((float2*)(&(" << src << "))+1));\n";
+      os << sret;
+      return;
+    }
+  }
+
+  // Handle conversion between bfloat16 and float32
+  if (from_ty.is_bfloat16() && target_ty.is_float()) {
+    // Use __bfloat1622float2 for vectorized conversion (bfloat162 -> float2)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // bfloat162 -> float2
+      PrintIndent();
+      stream << sret
+             << " = __bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
+             << src << ")));\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // bfloat162x2 -> float4
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[0] = "
+             << "__bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
+             << src << ")));\n";
+      PrintIndent();
+      stream << "((float2*)(&" << sret << "))[1] = "
+             << "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
+             << src << "))+1));\n";
+      os << sret;
+      return;
+    }
+  } else if (from_ty.is_float() && target_ty.is_bfloat16()) {
+    // Use __float22bfloat162_rn for vectorized conversion (float2 -> bfloat162)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // float2 -> bfloat162
+      PrintIndent();
+      stream << "*reinterpret_cast<__nv_bfloat162*>(&(" << sret
+             << ")) = __float22bfloat162_rn(*(float2*)(&(" << src << ")));\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // float4 -> bfloat162x2
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[0] = "
+             << "__float22bfloat162_rn(*(float2*)(&(" << src << ")));\n";
+      PrintIndent();
+      stream << "(reinterpret_cast<__nv_bfloat162*>(&" << sret << "))[1] = "
+             << "__float22bfloat162_rn(*((float2*)(&(" << src << "))+1));\n";
+      os << sret;
+      return;
+    }
+  }
+
+  // Handle conversion from float32 to float8 (E4M3/E5M2)
+  if (from_ty.is_float() &&
+      (target_ty.is_float8_e4m3() || target_ty.is_float8_e5m2())) {
+    // FP32 -> FP8: Use __nv_cvt_float2_to_fp8x2 for vectorized conversion
+    // (float2 -> fp8x2)
+    if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
+      // float2 -> fp8x2
+      PrintIndent();
+      stream << "*reinterpret_cast<__nv_fp8x2_storage_t*>(&(" << sret
+             << ")) = __nv_cvt_float2_to_fp8x2(*reinterpret_cast<float2*>(&("
+             << src << ")), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      os << sret;
+      return;
+    } else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
+      // float4 -> fp8x4
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = "
+             << "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src
+             << ")), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
+      PrintIndent();
+      stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = "
+             << "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
+             << "))+1), __NV_SATFINITE, "
+             << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
+             << ");\n";
     }
   }
 
@@ -964,9 +1031,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
     PrintVecElemStore(sret, target_ty, i, val.str());
   }
 
-  if (used_bf16_op) {
-    stream << "#endif\n";
-  }
   os << sret;
 }
 
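
Taken together, the new branches replace string-built `fastertransformer::` calls with standard CUDA intrinsics, and the second hunk drops the now-unneeded `#endif` emission. A sketch of the device code the new vectorized paths produce, reconstructed from the emission strings above (function and variable names are illustrative; the fp8 conversion needs cuda_fp8.h from CUDA 11.8+, with hardware support on sm_89+; the bf16 paths follow the same pattern as the fp16 ones, see the sketch earlier):

#include <cuda_fp16.h>
#include <cuda_fp8.h>

// 2-lane fp16 <-> fp32, as emitted by the lanes()==2 branches.
__device__ float2 f16x2_to_f32x2(half2 v) { return __half22float2(v); }
__device__ half2 f32x2_to_f16x2(float2 v) { return __float22half2_rn(v); }

// 4-lane casts are split into two 2-lane intrinsic calls, mirroring the
// lanes()==4 branches that write halves [0] and [1] separately.
__device__ void f16x4_to_f32x4(const half2 src[2], float2 dst[2]) {
  dst[0] = __half22float2(src[0]);
  dst[1] = __half22float2(src[1]);
}

// fp32 -> packed fp8 pair (E4M3 shown), saturating out-of-range values.
__device__ __nv_fp8x2_storage_t f32x2_to_fp8x2(float2 v) {
  return __nv_cvt_float2_to_fp8x2(v, __NV_SATFINITE, __NV_E4M3);
}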