Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
9358d89
[wip] feat: add nvrtc backend
lucifer1004 May 8, 2025
bace4a2
[wip] fix: handle out_idx
lucifer1004 May 9, 2025
024f3c4
[wip] refactor: move lib logic to libgen
lucifer1004 May 9, 2025
d3f92b0
feat: cache for nvrtc backend
lucifer1004 May 9, 2025
cac51fe
Merge remote-tracking branch 'upstream/main' into nvrtc
lucifer1004 May 9, 2025
037b0fe
fmt: run format
lucifer1004 May 9, 2025
3b70fd0
fix: handle cuda bindings import error
lucifer1004 May 9, 2025
63dd099
fix: handle cuda bindings import error
lucifer1004 May 9, 2025
eeca41b
fix: handle cuda bindings import error
lucifer1004 May 9, 2025
0c33224
fix: handle cuda bindings import error
lucifer1004 May 9, 2025
a04ab35
Merge remote-tracking branch 'upstream/main' into nvrtc
lucifer1004 May 9, 2025
2a31cc3
fix: get kernel source
lucifer1004 May 9, 2025
243b642
refactor: speedup pyimport
lucifer1004 May 9, 2025
4f0d1ce
Merge remote-tracking branch 'upstream/main' into nvrtc
lucifer1004 May 9, 2025
5672187
Merge remote-tracking branch 'upstream/main' into nvrtc
lucifer1004 May 20, 2025
8c7a286
Merge remote-tracking branch 'upstream/main' into nvrtc
lucifer1004 May 22, 2025
8effff8
Merge branch 'main' of https://github.com/tile-ai/tilelang into nvrtc
LeiWang1999 Jun 4, 2025
c314c64
Improve error handling for missing cuda-python dependency in nvrtc ba…
LeiWang1999 Jun 5, 2025
c77eb46
Merge branch 'main' of https://github.com/tile-ai/tilelang into nvrtc
LeiWang1999 Jun 5, 2025
d730e17
Enhance nvrtc backend error handling by introducing a flag to check f…
LeiWang1999 Jun 5, 2025
ed191a8
Update README.md to include recent NVRTC Backend addition, highlighti…
LeiWang1999 Jun 5, 2025
733e38e
fix tl_templates
lucifer1004 Jun 5, 2025
03694ac
ensure CUDA context
lucifer1004 Jun 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Tile Language (**tile-lang**) is a concise domain-specific language designed to
<img src=./images/MatmulExample.png />

## Latest News
- 05/06/2025 ✨: Added [NVRTC Backend](https://github.com/tile-ai/tilelang/pull/461) to significantly reduce compilation time for cute templates!
- 14/04/2025 🚀: Added high-performance FlashMLA implementation for AMD MI300X, achieving performance parity with hand-optimized assembly kernels of Aiter! See [example_mla_amd](./examples/deepseek_mla/amd/README.md) for details.
- 03/03/2025 🚀: Added high-performance MLA Decoding support using only 80 lines of Python code, achieving performance on par with FlashMLA on H100 (see [example_mla_decode.py](./examples/deepseek_mla/example_mla_decode.py))! We also provide [documentation](./examples/deepseek_mla/README.md) explaining how TileLang achieves this.
- 02/15/2025 ✨: Added WebGPU Codegen support, see [Pull Request #86](https://github.com/tile-ai/tilelang/pull/86)!
Expand Down
3 changes: 3 additions & 0 deletions src/tl_templates/cuda/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
// Licensed under the MIT License.
#pragma once

#ifndef __CUDACC_RTC__
#include <cuda_runtime.h>
#endif

#include <cutlass/fast_math.h>
#include <cutlass/numeric_types.h>
#include <math_constants.h>
Expand Down
8 changes: 4 additions & 4 deletions src/tl_templates/cuda/copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) {
static_assert(N == 16 || N == 8 || N == 4);
unsigned int addr = smem_ptr_to_uint(smem_addr);
if constexpr (N == 16) {
__asm__ __volatile__(
asm volatile(
#if TL_ENABLE_L2_PREFETCH
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2;"
#else
Expand All @@ -36,7 +36,7 @@ TL_DEVICE void cp_async_gs(void const *const smem_addr, void *global_ptr) {
::"r"(addr),
"l"((void *)(global_ptr)), "n"(N));
} else {
__asm__ __volatile__(
asm volatile(
#if TL_ENABLE_L2_PREFETCH
"cp.async.ca.shared.global.L2::128B [%0], [%1], %2;"
#else
Expand All @@ -54,7 +54,7 @@ TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr,
int bytes = cond ? N : 0;
unsigned int addr = smem_ptr_to_uint(smem_addr);
if constexpr (N == 16) {
__asm__ __volatile__(
asm volatile(
#if TL_ENABLE_L2_PREFETCH
"cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;"
#else
Expand All @@ -63,7 +63,7 @@ TL_DEVICE void cp_async_gs_conditional(void const *const smem_addr,
::"r"(addr),
"l"((void *)(global_ptr)), "n"(N), "r"(bytes));
} else {
__asm__ __volatile__(
asm volatile(
#if TL_ENABLE_L2_PREFETCH
"cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;"
#else
Expand Down
2 changes: 2 additions & 0 deletions src/tl_templates/cuda/copy_sm90.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
// Licensed under the MIT License.
#pragma once

#ifndef __CUDACC_RTC__
#include <cuda.h>
#endif

#include "common.h"

Expand Down
5 changes: 4 additions & 1 deletion src/tl_templates/cuda/debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

#include "./cuda_fp8.h"
#include "common.h"
#include <stdio.h>

#ifndef __CUDACC_RTC__
#include <cstdio>
#endif

// Template declaration for device-side debug printing (variable only)
template <typename T> __device__ void debug_print_var(const char *msg, T var);
Expand Down
5 changes: 3 additions & 2 deletions src/tl_templates/cuda/gemm_sm89.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
// Licensed under the MIT License.
#pragma once

#include "common.h"
#include "cuda_fp8.h"
#include <cute/algorithm/clear.hpp>
#include <cute/arch/mma_sm80.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/mma_traits.hpp>
#include <cute/underscore.hpp>

#include "common.h"
#include "cuda_fp8.h"

namespace cute {

template <typename A_type, typename B_type, typename C_type, int num_warp_m,
Expand Down
123 changes: 123 additions & 0 deletions src/tl_templates/cuda/nvrtc_std.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
* All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#ifdef __CUDACC_RTC__

using int8_t = signed char;
using uint8_t = unsigned char;
using int16_t = signed short;
using uint16_t = unsigned short;
using int32_t = signed int;
using uint32_t = unsigned int;
using int64_t = signed long long;
using uint64_t = unsigned long long;
using cuuint64_t = unsigned long long;

// Fallback declaration of the tensor-map descriptor for NVRTC builds.
// The CU_TENSOR_MAP_NUM_QWORDS guard skips this section when that macro has
// already been defined by an earlier header.
// NOTE(review): presumably this mirrors the CUtensorMap type from <cuda.h>,
// which the templates do not include under __CUDACC_RTC__ -- confirm the
// layout stays in sync with the driver API header.
#ifndef CU_TENSOR_MAP_NUM_QWORDS
// Number of 64-bit words in the descriptor (16 * 8 bytes = 128 bytes).
#define CU_TENSOR_MAP_NUM_QWORDS 16

// Opaque descriptor storage: an array of CU_TENSOR_MAP_NUM_QWORDS 64-bit
// words, aligned to 64 bytes when an alignment keyword is available.
struct CUtensorMap_st {
// Select the alignment keyword for the language mode in use: C++11 `alignas`
// or C11 `_Alignas`. No alignment specifier is emitted if neither applies.
#if defined(__cplusplus) && (__cplusplus >= 201103L)
alignas(64)
#elif __STDC_VERSION__ >= 201112L
_Alignas(64)
#endif
cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS];
};

// Convenience name for the descriptor struct.
using CUtensorMap = CUtensorMap_st;
#endif

namespace std {

// Wraps the compile-time constant `v` of type `T` in a distinct type.
//   value_type  - the type of the wrapped constant (T)
//   type        - this specialization itself
//   value       - the wrapped constant
// Both the call operator and the conversion operator yield the wrapped
// constant and are declared __device__ constexpr.
template <class T, T v> struct integral_constant {
  typedef T value_type;
  typedef integral_constant<T, v> type;

  static constexpr value_type value = v;

  __device__ constexpr value_type operator()() const noexcept { return v; }

  __device__ constexpr operator value_type() const noexcept { return v; }
};

// Boolean constant types built on integral_constant.
typedef integral_constant<bool, true> true_type;
typedef integral_constant<bool, false> false_type;

// is_same<A, B> derives from false_type in general ...
template <typename A, typename B> struct is_same : false_type {};

// ... and from true_type when both arguments name the same type.
template <typename A> struct is_same<A, A> : true_type {};

// Variable-template shorthand for is_same<A, B>::value.
template <typename A, typename B>
inline constexpr bool is_same_v = is_same<A, B>::value;

namespace index_sequence_impl {

// Based on https://stackoverflow.com/a/32223343/11717224
//
// Compile-time list of size_t indices. The nested `type` names the sequence
// itself, so helper structs that inherit from a sequence expose the plain
// sequence type through `::type`.
template <size_t... Ints> struct index_sequence {
  using type = index_sequence;
  using value_type = size_t;
  // Number of indices held in the pack.
  static constexpr size_t size() noexcept { return sizeof...(Ints); }
};

// Concatenates two sequences, offsetting every index of the second one by
// the length of the first: <0, 1> merged with <0, 1, 2> gives <0, 1, 2, 3, 4>.
template <class Sequence1, class Sequence2> struct _merge_and_renumber;

template <size_t... I1, size_t... I2>
struct _merge_and_renumber<index_sequence<I1...>, index_sequence<I2...>>
    : index_sequence<I1..., (sizeof...(I1) + I2)...> {};

// Generator for <0, 1, ..., N-1>. N is split in half at every step, so the
// recursion depth grows with log(N) rather than N.
// This struct only *derives* from the resulting sequence; use `::type` to
// obtain the sequence type itself.
template <size_t N>
struct make_index_sequence
    : _merge_and_renumber<typename make_index_sequence<N / 2>::type,
                          typename make_index_sequence<N - N / 2>::type> {};

// Recursion anchors.
template <> struct make_index_sequence<0> : index_sequence<> {};
template <> struct make_index_sequence<1> : index_sequence<0> {};

} // namespace index_sequence_impl

// Public name for the index list type.
template <size_t... Ns>
using index_sequence = index_sequence_impl::index_sequence<Ns...>;

// Public generator alias. It resolves to the plain index_sequence<0, ..., N-1>
// (via `::type`) instead of the generator struct, so that
// make_index_sequence<N> and index_sequence<0, ..., N-1> are the *same* type.
// This matches the standard facility and keeps exact-type checks and partial
// specializations on index_sequence<Is...> working.
template <size_t N>
using make_index_sequence =
    typename index_sequence_impl::make_index_sequence<N>::type;

// Returns `a` when `a < b` holds, otherwise `b` (so `b` is returned for
// equal or unordered operands).
template <typename T> constexpr T min(T a, T b) {
  if (a < b) {
    return a;
  }
  return b;
}

// Returns `a` when `a > b` holds, otherwise `b` (so `b` is returned for
// equal or unordered operands).
template <typename T> constexpr T max(T a, T b) {
  if (a > b) {
    return a;
  }
  return b;
}

// Compile-time type selection: conditional<B, T, F>::type is T when B is
// true and F when B is false.
//
// General case (reached only for B == false, since `true` is specialized).
template <bool B, class T, class F> struct conditional {
  typedef F type;
};

// B == true: pick the first alternative.
template <class T, class F> struct conditional<true, T, F> {
  typedef T type;
};

// Shorthand for typename conditional<B, T, F>::type.
template <bool B, class T, class F>
using conditional_t = typename conditional<B, T, F>::type;

// SFINAE helper: the nested `type` (defaulting to void) exists only when the
// condition is true. The general template deliberately has no members.
template <bool B, class T = void> struct enable_if {};

// Condition satisfied: expose T as `type`.
template <class T> struct enable_if<true, T> {
  typedef T type;
};
} // namespace std

#endif
4 changes: 2 additions & 2 deletions tilelang/cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ def cached(
*args,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
execution_backend: Optional[Literal["dlpack", "ctypes", "cython"]] = "cython",
execution_backend: Optional[Literal["dlpack", "ctypes", "cython", "nvrtc"]] = "cython",
verbose: Optional[bool] = False,
pass_configs: Optional[dict] = None,
) -> JITKernel:
"""
Caches and reuses compiled kerne(ls (using KernelCache class).
Caches and reuses compiled kernels (using KernelCache class).
"""
return _kernel_cache_instance.cached(
func,
Expand Down
51 changes: 34 additions & 17 deletions tilelang/cache/kernel_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,29 @@
# Licensed under the MIT License.
"""The cache utils with class and database persistence - KernelCache Class"""

import os
import json
import logging
import os
import shutil
from pathlib import Path
import threading
from hashlib import sha256
from typing import Callable, List, Literal, Union, Optional
from pathlib import Path
from typing import Callable, List, Literal, Optional, Union

import cloudpickle
from tvm.target import Target
from tvm.tir import PrimFunc
from tilelang.jit import JITKernel
from tilelang.engine.param import KernelParam
import threading
import cloudpickle
import logging

from tilelang.engine.param import KernelParam
from tilelang.env import TILELANG_CACHE_DIR, is_cache_enabled
from tilelang.jit import JITKernel
from tilelang.version import __version__

KERNEL_PATH = "kernel.cu"
WRAPPED_KERNEL_PATH = "wrapped_kernel.cu"
KERNEL_LIB_PATH = "kernel_lib.so"
KERNEL_CUBIN_PATH = "kernel.cubin"
KERNEL_PY_PATH = "kernel.py"
PARAMS_PATH = "params.pkl"


Expand All @@ -38,6 +41,7 @@ class KernelCache:
_instance = None # For implementing singleton pattern
_lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython"

cache_dir: Path = Path(TILELANG_CACHE_DIR)

Expand Down Expand Up @@ -68,7 +72,7 @@ def _generate_key(
self,
func: Callable,
out_idx: List[int],
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
args=None,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
Expand All @@ -88,6 +92,7 @@ def _generate_key(
Returns:
str: SHA256 hash key for the kernel configuration.
"""
self.execution_backend = execution_backend
func_binary = cloudpickle.dumps(func.script())
key_data = {
"version": __version__,
Expand All @@ -101,8 +106,10 @@ def _generate_key(
"execution_backend": execution_backend,
"pass_configs": pass_configs,
}
key_string = json.dumps(key_data, sort_keys=True) # Sort keys to ensure consistency
return sha256(key_string.encode()).hexdigest() # Use SHA256 to generate hash key
# Sort keys to ensure consistency
key_string = json.dumps(key_data, sort_keys=True)
# Use SHA256 to generate hash key
return sha256(key_string.encode()).hexdigest()

def cached(
self,
Expand All @@ -111,7 +118,7 @@ def cached(
*args,
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
verbose: bool = False,
pass_configs: dict = None,
) -> JITKernel:
Expand Down Expand Up @@ -253,9 +260,15 @@ def _save_kernel_to_disk(self, key: str, kernel: JITKernel, func: Callable = Non

# Save kernel library
try:
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
if self.execution_backend == "nvrtc":
kernel_lib_path = os.path.join(cache_path, KERNEL_CUBIN_PATH)
else:
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
src_lib_path = kernel.adapter.libpath
shutil.copy(src_lib_path, kernel_lib_path)
if self.execution_backend == "nvrtc":
shutil.copy(
src_lib_path.replace(".cubin", ".py"), os.path.join(cache_path, KERNEL_PY_PATH))
except Exception as e:
self.logger.error(f"Error saving kernel library to disk: {e}")

Expand All @@ -273,7 +286,7 @@ def _load_kernel_from_disk(
target: Union[str, Target] = "auto",
target_host: Union[str, Target] = None,
out_idx: List[int] = None,
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
pass_configs: dict = None,
func: Callable = None,
) -> JITKernel:
Expand Down Expand Up @@ -306,7 +319,10 @@ def _load_kernel_from_disk(
except Exception as e:
self.logger.error(f"Error loading wrapped kernel source code from disk: {e}")

kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)
if self.execution_backend == "nvrtc":
kernel_lib_path = os.path.join(cache_path, KERNEL_CUBIN_PATH)
else:
kernel_lib_path = os.path.join(cache_path, KERNEL_LIB_PATH)

# Load kernel parameters
try:
Expand Down Expand Up @@ -334,14 +350,15 @@ def _load_kernel_from_disk(
def _clear_disk_cache(self):
"""
Removes all cached kernels from disk.

Note:
This operation will delete the entire cache directory and recreate it empty.
Use with caution as this operation cannot be undone.
"""
try:
if os.path.exists(self.cache_dir):
shutil.rmtree(self.cache_dir) # Delete entire cache directory
os.makedirs(self.cache_dir, exist_ok=True) # Re-create cache directory
# Re-create cache directory
os.makedirs(self.cache_dir, exist_ok=True)
except Exception as e:
self.logger.error(f"Error clearing disk cache: {e}")
Loading
Loading