Skip to content

Commit

Permalink
feat: add JIT compilation support for FA3 templates (#672)
Browse files Browse the repository at this point in the history
Follow-up work to #667
  • Branch information
yzh119 authored Dec 17, 2024
1 parent d2ebd1e commit d4e8d79
Show file tree
Hide file tree
Showing 6 changed files with 610 additions and 23 deletions.
6 changes: 3 additions & 3 deletions flashinfer/jit/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
batch_prefill_sm90_templ,
)
from .batch_prefill_templ import batch_prefill_suffix, batch_prefill_templ
from .core import load_cuda_ops
from .core import load_cuda_ops, sm90a_nvcc_flags
from .env import FLASHINFER_GEN_SRC_DIR
from .single_decode_templ import (
customizable_single_decode_templ,
Expand Down Expand Up @@ -333,7 +333,7 @@ def gen_single_prefill_sm90_module(*args):
source_paths.append(path)
write_if_different(path, source)

return load_cuda_ops(uri, source_paths)
return load_cuda_ops(uri, source_paths, extra_cuda_cflags=sm90a_nvcc_flags)


def get_batch_prefill_sources(
Expand Down Expand Up @@ -445,7 +445,7 @@ def gen_batch_prefill_sm90_module(*args):
source_paths.append(path)
write_if_different(path, source)

return load_cuda_ops(uri, source_paths)
return load_cuda_ops(uri, source_paths, extra_cuda_cflags=sm90a_nvcc_flags)


def get_customize_single_decode_sources(
Expand Down
Loading

0 comments on commit d4e8d79

Please sign in to comment.