-
Notifications
You must be signed in to change notification settings - Fork 3.8k
Open
Labels
needs-triage: PRs or issues that need to be investigated by maintainers to find the right assignees to address it; type: bug
Description
For Metal codegen, when we use tvm_callback_metal_compile as a Python-side debugger, it also sets the format to metallib:
tvm/src/target/source/codegen_metal.cc
Lines 441 to 442 in 52e4547
| const auto fmetal_compile = tvm::ffi::Function::GetGlobal("tvm_callback_metal_compile"); | |
| std::string fmt = fmetal_compile ? "metallib" : "metal"; |
which makes MetalModuleNode treat the text source as a binary metallib and try to load it with newLibraryWithData, causing TVM to throw a tvm.error.InternalError: Fail to compile metal lib:Invalid library file
tvm/src/runtime/metal/metal_module.mm
Lines 123 to 128 in 52e4547
| } else { | |
| // Build from library. | |
| auto q = dispatch_queue_create("q", DISPATCH_QUEUE_SERIAL); | |
| auto data = dispatch_data_create(source.c_str(), source.length(), q, | |
| ^{ | |
| }); |
Also, with tvm-ffi>=0.1.8, it does not give an error message but crashes instead, which seems problematic:
(building from v0.1.8.post2 source is good)
python(69796,0x1f328ec40) malloc: *** error for object 0x8000000000000070: pointer being freed was not allocated
python(69796,0x1f328ec40) malloc: *** set a breakpoint in malloc_error_break to debug
reproducer:
import tilelang
print("Imported tilelang")
from tilelang import tvm as tvm
from time import sleep
# import tilelang.testing
import tilelang.language as T
import json
import torch
import os
print("Imports done", flush=True)
from tilelang.engine.callback import register_metal_postproc_callback
@register_metal_postproc_callback
def _p(code, target):
    """Metal post-processing hook: echo the generated source and pass it through unchanged."""
    source_text = code
    print(source_text)
    return source_text
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="float"):
    """Build a JIT-compiled TileLang tiled GEMM kernel computing C = A @ B.

    M, N, K are the problem sizes; block_M/block_N/block_K are the tile sizes.
    dtype is the element type of A/B/C; accum_dtype is the accumulator type.
    Returns the TileLang prim_func implementing the kernel.
    """

    @T.prim_func
    def gemm(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        # Grid of ceil(N/block_N) x ceil(M/block_M) blocks, 128 threads each;
        # bx indexes the N (column) tiles, by the M (row) tiles.
        with T.Kernel(
                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
                    bx,
                    by,
                ):
            # Shared-memory staging tiles for the A and B operands.
            A_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared")
            B_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared")
            # Accumulator fragment in accum_dtype, zeroed before the K loop.
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            # Walk the K dimension one tile at a time; num_stages=0 presumably
            # disables software pipelining — TODO confirm against TileLang docs.
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                T.copy(A[by * block_M, ko * block_K], A_shared)
                T.copy(B[ko * block_K, bx * block_N], B_shared)
                # Inner product accumulation over the staged tiles.
                for i, j in T.Parallel(block_M, block_N):
                    for k in T.Serial(block_K):
                        C_local[i, j] += A_shared[i, k] * B_shared[k, j]
            # Write the accumulated tile back to its slot in C.
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm
def benchmark(f, n, *args, **kwargs):
    """Time n invocations of f(*args, **kwargs) on the MPS device.

    Returns the elapsed wall time in seconds (MPS events report milliseconds).
    """
    # Warm-up call so one-time JIT compilation is excluded from the timing.
    f(*args, **kwargs)
    torch.mps.synchronize()
    with torch.mps.profiler.profile(mode="interval,event", wait_until_completed=True):
        begin_evt = torch.mps.Event(enable_timing=True)
        finish_evt = torch.mps.Event(enable_timing=True)
        begin_evt.record()
        for _ in range(n):
            f(*args, **kwargs)
        finish_evt.record()
        # Block until both events have actually completed on the device.
        begin_evt.synchronize()
        finish_evt.synchronize()
        # elapsed_time is in milliseconds; convert to seconds.
        return begin_evt.elapsed_time(finish_evt) / 1000
if __name__ == "__main__":
    # 128x128x128 square GEMM in half precision on the MPS backend.
    m = n = k = 128
    torch_dtype = torch.float16
    dtype = 'float16'

    a = torch.randn(m, k, device="mps", dtype=torch_dtype)
    b = torch.randn(k, n, device="mps", dtype=torch_dtype)
    c = torch.zeros(m, n, device="mps", dtype=torch_dtype)

    # torch_add = lambda: torch.matmul(a, b, out=c)
    # torch_add()
    # print(benchmark(torch_add, n=100))

    # Compiling the kernel is what triggers the Metal codegen bug.
    print("Starting compilation...", flush=True)
    jit_kernel = matmul(m, n, k, 16, 16, 16, dtype=dtype, accum_dtype="float")
    print("Compilation finished.", flush=True)
    # print(jit_kernel.get_kernel_source())

    jit_kernel(a, b, c)
    # Print the kernel output next to the torch reference for eyeball comparison.
    print(c)
    print(a @ b)
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
needs-triage: PRs or issues that need to be investigated by maintainers to find the right assignees to address it; type: bug