Skip to content

Runtime error with single_prefill_with_kv_cache while Compilation #541

@YudiZh

Description

@YudiZh

I tried to compile single_prefill_with_kv_cache using torch.compile.

import torch
from flashinfer import single_prefill_with_kv_cache

data_type = torch.bfloat16

QH=64
KH=8
S=1024
D=128

def generate_data():
    q = torch.randn(S, QH, D, device='cuda', dtype=data_type)
    k = torch.randn(S, KH, D, device='cuda', dtype=data_type)
    v = torch.randn(S, KH, D, device='cuda', dtype=data_type)
    return q, k, v

def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

torch.library.define(
    "mylib::custom_func_flashinfer",
    "(Tensor q, Tensor k, Tensor v, Tensor custom_mask) -> Tensor",
)

@torch.library.impl("mylib::custom_func_flashinfer", "cuda")
def custom_func_flashinfer(q, k, v, custom_mask):
    return single_prefill_with_kv_cache(
        q, k, v, custom_mask=custom_mask
    )

@torch.library.impl_abstract("mylib::custom_func_flashinfer")
def custom_func_flashinfer_abstract(q, k, v, custom_mask):
    return torch.empty_like(q)


def attn(q, k, v, custom_mask=None):
    return torch.ops.mylib.custom_func_flashinfer(q, k, v, custom_mask=custom_mask)
attn = torch.compile(attn, mode="reduce-overhead", fullgraph=True)


for i in range(10):
    q, k, v = generate_data()
    mask = torch.tril(
        torch.full((S, S), True, device="cuda:0"),
    )
    o, run_time = timed(lambda: attn(q, k, v, custom_mask=mask))
    print(run_time)

cause following runtime error

/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py:37: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("mylib::custom_func_flashinfer")
Traceback (most recent call last):
  File "/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py", line 52, in <module>
    o, run_time = timed(lambda: attn(q, k, v, custom_mask=mask))
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py", line 21, in timed
    result = fn()
             ^^^^
  File "/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py", line 52, in <lambda>
    o, run_time = timed(lambda: attn(q, k, v, custom_mask=mask))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/project/code_test/flashinfer_test/compilation.py", line 42, in attn
    def attn(q, k, v, custom_mask=None):
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 987, in forward
    return compiled_fn(full_args)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 217, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/utils.py", line 120, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
                            ^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 451, in wrapper
    return compiled_fn(runtime_args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 1131, in __call__
    return self.current_callable(inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 993, in run
    return compiled_fn(new_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 373, in deferred_cudagraphify
    fn, out = cudagraphify(model, inputs, new_static_input_idxs, *args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 403, in cudagraphify
    return manager.add_function(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 2089, in add_function
    return fn, fn(inputs)
               ^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 1841, in run
    out = self._run(new_inputs, function_id)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 1932, in _run
    return self.run_eager(new_inputs, function_id)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 2055, in run_eager
    return node.run(new_inputs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 646, in run
    check_memory_pool(self.device_index, self.cuda_graphs_pool, out_refs)
  File "/data/home/ydzhang/miniconda3/envs/torch24/lib/python3.11/site-packages/torch/_inductor/cudagraph_trees.py", line 1699, in check_memory_pool
    raise RuntimeError(msg)
RuntimeError: These live storage data ptrs are in the cudagraph pool but not accounted for as an output of cudagraph trees: 

Data Pointer: 22959854977024, history: 

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions