Skip to content

[Bug][Metal] metal codegen hook introduce unexpected side effect #18798

@oraluben

Description

@oraluben

For Metal codegen, when tvm_callback_metal_compile is registered as a Python-side debugging hook, the codegen also sets the module format to metallib:

const auto fmetal_compile = tvm::ffi::Function::GetGlobal("tvm_callback_metal_compile");
std::string fmt = fmetal_compile ? "metallib" : "metal";

This makes MetalModuleNode treat the text source as a binary metallib and try to load it with newLibraryWithData, which causes TVM to throw a tvm.error.InternalError: Fail to compile metal lib:Invalid library file

} else {
// Build from library.
auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL);
auto data = dispatch_data_create(source.c_str(), source.length(), q,
^{
});

Also, with tvm-ffi>=0.1.8, it does not give an error message but crashes instead, which is not good:
(building from the v0.1.8.post2 source works fine)

python(69796,0x1f328ec40) malloc: *** error for object 0x8000000000000070: pointer being freed was not allocated
python(69796,0x1f328ec40) malloc: *** set a breakpoint in malloc_error_break to debug

reproducer:

import tilelang
print("Imported tilelang")
from tilelang import tvm as tvm
from time import sleep
# import tilelang.testing
import tilelang.language as T
import json
import torch
import os
print("Imports done", flush=True)


from tilelang.engine.callback import register_metal_postproc_callback

@register_metal_postproc_callback
def _p(code, target):
    # Debug-only post-processing hook: print the code handed to the callback
    # and return it unchanged. `target` is accepted but not used.
    # Registering this hook is what triggers the behavior reported in this issue.
    print(code)
    return code


@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="float"):
    # Build and return a tiled matrix-multiply prim_func computing C = A @ B
    # for A of shape (M, K), B of shape (K, N) and C of shape (M, N).
    # block_M / block_N / block_K are the tile sizes; `dtype` is the element
    # type of A, B, C and of the staging buffers, `accum_dtype` the type of the
    # accumulator buffer.

    @T.prim_func
    def gemm(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # One kernel instance per (block_M x block_N) output tile: bx indexes
        # tiles along N, by indexes tiles along M.
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
                    bx,
                    by,
                ):
            # Staging buffers for the current A and B tiles, plus the
            # accumulator for the output tile.
            A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared")
            B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared")
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Start accumulation from a cleared buffer.
            T.clear(C_local)

            # Walk the K dimension one block_K-wide slice at a time.
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                # Load the A tile at (by, ko) and the B tile at (ko, bx).
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)

                # Accumulate the partial product of the two staged tiles.
                for i, j in T.Parallel(block_M, block_N):
                    for k in T.Serial(block_K):
                        C_local[i, j] += A_shared[i, k] * B_shared[k, j]

            # Write the finished tile back to its position in C.
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm


def benchmark(f, n, *args, **kwargs):
    # Time `n` consecutive calls of `f(*args, **kwargs)` with torch.mps events
    # and return the measured elapsed time divided by 1000.
    # NOTE(review): elapsed_time presumably reports milliseconds, which would
    # make the return value seconds — confirm against the torch.mps.Event docs.

    # trigger jit
    f(*args, **kwargs)

    # Synchronize after the warm-up call, before the timed region starts.
    torch.mps.synchronize()
    with torch.mps.profiler.profile(mode="interval,event", wait_until_completed=True):
        start = torch.mps.Event(enable_timing=True)
        end = torch.mps.Event(enable_timing=True)
        start.record()

        for _ in range(n):
            f(*args, **kwargs)

        end.record()

        # Synchronize on both events before reading the elapsed time.
        start.synchronize()
        end.synchronize()

        return start.elapsed_time(end) / 1000


if __name__ == "__main__":
    # Reproducer driver: build 128x128 float16 operands on the "mps" device,
    # JIT-compile the tiled matmul, run it once, and print the result next to
    # the torch reference.
    m = n = k = 128
    torch_dtype = torch.float16
    dtype = 'float16'

    a = torch.randn(m, k, device="mps", dtype=torch_dtype)
    b = torch.randn(k, n, device="mps", dtype=torch_dtype)
    c = torch.zeros(m, n, device="mps", dtype=torch_dtype)

    # torch_add = lambda: torch.matmul(a, b, out=c)
    # torch_add()
    # print(benchmark(torch_add, n=100))

    # Compile with 16x16x16 tiles; the progress prints bracket the compilation
    # call and are flushed immediately.
    print("Starting compilation...", flush=True)
    jit_kernel = matmul(m, n, k, 16, 16, 16, dtype=dtype, accum_dtype="float")
    print("Compilation finished.", flush=True)

    # print(jit_kernel.get_kernel_source())
    # Run the kernel into `c`, then print it alongside the torch result a @ b
    # for a visual comparison.
    jit_kernel(a, b, c)
    print(c)
    print(a @ b)

cc @echuraev @junrushao

Metadata

Metadata

Assignees

No one assigned

    Labels

    needs-triage: PRs or issues that need to be investigated by maintainers to find the right assignees to address it; type: bug

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions