Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions examples/dequantize_gemm/test_example_dequantize_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import example_dequant_gemv_fp16xint4
import example_dequant_gemm_fp4_hopper
import example_dequant_gemm_bf16_fp4_hopper_serial


@tilelang.testing.requires_cuda
Expand All @@ -16,11 +15,5 @@ def test_example_dequant_gemm_fp4_hopper():
example_dequant_gemm_fp4_hopper.main()


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_dequant_gemm_bf16_fp4_hopper_serial():
example_dequant_gemm_bf16_fp4_hopper_serial.main()


if __name__ == "__main__":
tilelang.testing.main()
1 change: 0 additions & 1 deletion examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch
import torch.backends
from tilelang import tvm as tvm
import tilelang.testing
from tvm import DataType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def main(batch=1, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64):
num_split = 1

kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
print(kernel.get_kernel_source())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This print statement appears to be for debugging purposes. It should be removed before merging.

profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
latency = profiler.do_bench(warmup=500)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def main():

# Run the kernel through the Profiler
c = jit_kernel(a, b)

# Reference multiplication using PyTorch
ref_c = a @ b

Expand Down
2 changes: 1 addition & 1 deletion src/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ TIR_DEFINE_TL_BUILTIN(ptx_ldmatirx)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_TL_BUILTIN(ptx_stmatirx)
TIR_DEFINE_TL_BUILTIN(ptx_stmatrix)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
Expand Down
42 changes: 21 additions & 21 deletions src/op/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ static constexpr const char *kDynamicAlignment = "tl.dynamic_alignment";
* swizzle, l2_promotion, oob_fill)
*
*/
const Op &create_tma_descriptor();
TVM_DLL const Op &create_tma_descriptor();

/*!
* \brief tvm intrinsics for TMADescriptor creation for image to column load
Expand All @@ -73,23 +73,23 @@ const Op &create_tma_descriptor();
* l2_promotion, oob_fill)
*
*/
const Op &create_tma_im2col_descriptor();
TVM_DLL const Op &create_tma_im2col_descriptor();

/*!
* \brief Create a list of mbarrier with num_threads
*
* create_list_of_mbarrier(num_threads0, num_threads1, ...)
*
*/
const Op &create_list_of_mbarrier();
TVM_DLL const Op &create_list_of_mbarrier();

/*!
* \brief Get the mbarrier with barrier_id
*
* int64_t* GetMBarrier(barrier_id)
*
*/
const Op &get_mbarrier();
TVM_DLL const Op &get_mbarrier();

/*!
* \brief tvm intrinsics for loading data from global tensor descriptor to
Expand All @@ -98,7 +98,7 @@ const Op &get_mbarrier();
* tma_load(descriptor, mbarrier, smem_data, coord_0, coord_1, ...)
*
*/
const Op &tma_load();
TVM_DLL const Op &tma_load();

/*!
* \brief tvm intrinsics for loading image from global tensor to columns in
Expand All @@ -108,7 +108,7 @@ const Op &tma_load();
* image_offset, ...)
*
*/
const Op &tma_load_im2col();
TVM_DLL const Op &tma_load_im2col();

/*!
* \brief tvm intrinsics for storing data from shared memory to global tensor
Expand All @@ -117,119 +117,119 @@ const Op &tma_load_im2col();
* tma_store(descriptor, smem_data, coord_0, coord_1, ...)
*
*/
const Op &tma_store();
TVM_DLL const Op &tma_store();

/*!
* \brief tvm intrinsics for mbarrier wait with parity bit
*
* mbarrier_wait_parity(mbarrier, parity)
*
*/
const Op &mbarrier_wait_parity();
TVM_DLL const Op &mbarrier_wait_parity();

/*!
* \brief tvm intrinsics for mbarrier expect tx
*
* mbarrier_expect_tx(mbarrier, transaction_bytes)
*
*/
const Op &mbarrier_expect_tx();
TVM_DLL const Op &mbarrier_expect_tx();

/*!
* \brief tvm intrinsics for ldmatrix
*
* ptx_ldmatirx(transposed, num, shared_addr, local_addr)
*
*/
const Op &ptx_ldmatirx();
TVM_DLL const Op &ptx_ldmatirx();

/*!
* \brief tvm intrinsics for stmatrix
*
* ptx_stmatrix(transposed, num, shared_addr, int32_values...)
*
*/
const Op &ptx_stmatirx();
TVM_DLL const Op &ptx_stmatrix();

/*!
* \brief Pack two b16 values into a b32 value
*
* int32 pack_b16(b16_value, b16_value)
*
*/
const Op &pack_b16();
TVM_DLL const Op &pack_b16();

/*!
* \brief Similar to __syncthreads(), but can be used to sync partial threads
*
* sync_thread_partial(num_partial_threads or mbarrier)
*
*/
const Op &sync_thread_partial();
TVM_DLL const Op &sync_thread_partial();

/*!
* \brief Issue a shared memory fence for async operations
*
* FenceProxyAsync()
*
*/
const Op &fence_proxy_async();
TVM_DLL const Op &fence_proxy_async();

/*!
* \brief Indicate arrival of warp issuing TMA_STORE
*
* tma_store_arrive()
*
*/
const Op &tma_store_arrive();
TVM_DLL const Op &tma_store_arrive();

/*!
* \brief Wait for TMA_STORE to finish
*
* tma_store_wait()
*
*/
const Op &tma_store_wait();
TVM_DLL const Op &tma_store_wait();

/*!
* \brief Set reg hint for warp-specialized branches
*
* SetMaxNRegInc(num_reg, is_inc)
*
*/
const Op &set_max_nreg();
TVM_DLL const Op &set_max_nreg();

/*!
* \brief Do not set a reg hint for warp-specialized branches
*
* no_set_max_nreg()
*
*/
const Op &no_set_max_nreg();
TVM_DLL const Op &no_set_max_nreg();

/*!
* \brief Wait for the previous wgmma to finish
*
* wait_wgmma(num_mma)
*
*/
const Op &wait_wgmma();
TVM_DLL const Op &wait_wgmma();

/*!
* \brief Synchronize all threads in a grid
*
* sync_grid()
*
*/
const Op &sync_grid();
TVM_DLL const Op &sync_grid();

/*!
* \brief tvm intrinsic for loop break
*
* loop_break()
*
*/
const Op &loop_break();
TVM_DLL const Op &loop_break();

/*!
* \brief tvm intrinsic for amd matrix core mfma instructions.
Expand Down
2 changes: 1 addition & 1 deletion src/op/elem.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ Stmt Copy::LowerLDSMCopy(const LowerArgs &T, arith::Analyzer *analyzer) const {
num = 2;

Array<PrimExpr> args;
const Op &op = is_ldmatrix ? tl::ptx_ldmatirx() : tl::ptx_stmatirx();
const Op &op = is_ldmatrix ? tl::ptx_ldmatirx() : tl::ptx_stmatrix();
args.push_back(static_cast<int>(is_transposed));
args.push_back(num);

Expand Down
Loading