-
Notifications
You must be signed in to change notification settings - Fork 590
Open
Labels
Description
I tried to compile single_prefill_with_kv_cache using torch.compile.
import torch
from flashinfer import single_prefill_with_kv_cache
data_type = torch.bfloat16
QH=64
KH=8
S=1024
D=128
def generate_data():
q = torch.randn(S, QH, D, device='cuda', dtype=data_type)
k = torch.randn(S, KH, D, device='cuda', dtype=data_type)
v = torch.randn(S, KH, D, device='cuda', dtype=data_type)
return q, k, v
def timed(fn):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
result = fn()
end.record()
torch.cuda.synchronize()
return result, start.elapsed_time(end) / 1000
torch.library.define(
"mylib::custom_func_flashinfer",
"(Tensor q, Tensor k, Tensor v, Tensor custom_mask) -> Tensor",
)
@torch.library.impl("mylib::custom_func_flashinfer", "cuda")
def custom_func_flashinfer(q, k, v, custom_mask):
return single_prefill_with_kv_cache(
q, k, v, custom_mask=custom_mask
)
@torch.library.impl_abstract("mylib::custom_func_flashinfer")
def custom_func_flashinfer_abstract(q, k, v, custom_mask):
return torch.empty_like(q)
def attn(q, k, v, custom_mask=None):
return torch.ops.mylib.custom_func_flashinfer(q, k, v, custom_mask=custom_mask)
attn = torch.compile(attn, mode="reduce-overhead", fullgraph=True)
for i in range(10):
q, k, v = generate_data()
mask = torch.tril(
torch.full((S, S), True, device="cuda:0"),
)
o, run_time = timed(lambda: attn(q, k, v, custom_mask=mask))
print(run_time)cause following runtime error
/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py:37: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
@torch.library.impl_abstract("mylib::custom_func_flashinfer")
Traceback (most recent call last):
File "/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py", line 52, in <module>
o, run_time = timed(lambda: attn(q, k, v, custom_mask=mask))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py", line 21, in timed
result = fn()
^^^^
File "/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py", line 52, in <lambda>
o, run_time = timed(lambda: attn(q, k, v, custom_mask=mask))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py", line 42, in attn
def attn(q, k, v, custom_mask=None):
File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 987, in forward
return compiled_fn(full_args)
^^^^^^^^^^^^^^^^^^^^^^
File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 217, in runtime_wrapper
all_outs = call_func_at_runtime_with_args(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 120, in call_func_at_runtime_with_args
out = normalize_as_list(f(args))
^^^^^^^
File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 451, in wrapper
return compiled_fn(runtime_args)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 1131, in __call__
return self.current_callable(inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 993, in run
return compiled_fn(new_inputs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 373, in deferred_cudagraphify
fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 403, in cudagraphify
return manager.add_function(
^^^^^^^^^^^^^^^^^^^^^
File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 2089, in add_function
return fn, fn(inputs)
^^^^^^^^^^
File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 1841, in run
out = self._run(new_inputs, function_id)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 1932, in _run
return self.run_eager(new_inputs, function_id)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 2055, in run_eager
return node.run(new_inputs)
^^^^^^^^^^^^^^^^^^^^
File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 646, in run
check_memory_pool(self.device_index, self.cuda_graphs_pool, out_refs)
File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 1699, in check_memory_pool
raise RuntimeError(msg)
RuntimeError: These live storage data ptrs are in the cudagraph pool but not accounted for as an output of cudagraph trees:
Data Pointer: 22959854977024, history: