Skip to content

Commit 9e9edb2

Browse files
authored
Do not import NVSHMEM in the AoT script unless explicitly requested (#1506)
<!-- .github/pull_request_template.md --> ## 📌 Description <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 20ab8ab commit 9e9edb2

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

flashinfer/aot.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
from .activation import act_func_def_str, gen_act_and_mul_module
1313
from .cascade import gen_cascade_module
14-
from .comm.nvshmem import gen_nvshmem_module
1514
from .fp4_quantization import gen_fp4_quantization_sm100_module
1615
from .fused_moe import gen_cutlass_fused_moe_sm100_module
1716
from .gemm import gen_gemm_module, gen_gemm_sm90_module, gen_gemm_sm100_module
@@ -372,6 +371,9 @@ def gen_all_modules(
372371

373372
if add_comm:
374373
from .comm import gen_trtllm_comm_module, gen_vllm_comm_module
374+
from .comm.nvshmem import gen_nvshmem_module
375+
376+
jit_specs.append(gen_nvshmem_module())
375377

376378
if has_sm100:
377379
jit_specs.append(gen_trtllm_comm_module())
@@ -381,7 +383,6 @@ def gen_all_modules(
381383
jit_specs += [
382384
gen_cascade_module(),
383385
gen_norm_module(),
384-
gen_nvshmem_module(),
385386
gen_page_module(),
386387
gen_quantization_module(),
387388
gen_rope_module(),

0 commit comments

Comments
 (0)