|
19 | 19 | @pytest.fixture(autouse=True) |
20 | 20 | def enable_batch_invariant_mode(): |
21 | 21 | """Automatically enable batch invariant kernel overrides for all tests.""" |
22 | | - old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT") |
23 | | - os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1" |
| 22 | + old_value = os.environ.get("VLLM_BATCH_INVARIANT") |
| 23 | + os.environ["VLLM_BATCH_INVARIANT"] = "1" |
24 | 24 | yield |
25 | 25 | # Restore original value after test |
26 | 26 | if old_value is None: |
27 | | - os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None) |
| 27 | + os.environ.pop("VLLM_BATCH_INVARIANT", None) |
28 | 28 | else: |
29 | | - os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value |
| 29 | + os.environ["VLLM_BATCH_INVARIANT"] = old_value |
30 | 30 |
|
31 | 31 |
|
32 | 32 | def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: |
@@ -231,10 +231,10 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend): |
231 | 231 | # For batch invariance, disable custom all-reduce to ensure deterministic |
232 | 232 | # all-reduce operations (custom all-reduce may not be deterministic) |
233 | 233 | from vllm.model_executor.layers.batch_invariant import ( |
234 | | - vllm_kernel_override_batch_invariant, |
| 234 | + vllm_is_batch_invariant, |
235 | 235 | ) |
236 | 236 |
|
237 | | - disable_custom_ar = vllm_kernel_override_batch_invariant() |
| 237 | + disable_custom_ar = vllm_is_batch_invariant() |
238 | 238 |
|
239 | 239 | if disable_custom_ar: |
240 | 240 | print(f"\n{'=' * 80}") |
@@ -494,8 +494,8 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend): |
494 | 494 | os.environ["VLLM_ATTENTION_BACKEND"] = backend |
495 | 495 |
|
496 | 496 | # CRITICAL: Disable batch invariance for this test |
497 | | - old_value = os.environ.get("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT") |
498 | | - os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "0" |
| 497 | + old_value = os.environ.get("VLLM_BATCH_INVARIANT") |
| 498 | + os.environ["VLLM_BATCH_INVARIANT"] = "0" |
499 | 499 |
|
500 | 500 | try: |
501 | 501 | seed = int(os.getenv("VLLM_TEST_SEED", "12345")) |
@@ -687,9 +687,9 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend): |
687 | 687 | finally: |
688 | 688 | # Restore original value |
689 | 689 | if old_value is None: |
690 | | - os.environ.pop("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", None) |
| 690 | + os.environ.pop("VLLM_BATCH_INVARIANT", None) |
691 | 691 | else: |
692 | | - os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = old_value |
| 692 | + os.environ["VLLM_BATCH_INVARIANT"] = old_value |
693 | 693 |
|
694 | 694 |
|
695 | 695 | @hopper_only |
@@ -718,10 +718,10 @@ def test_decode_logprobs_match_prefill_logprobs(backend): |
718 | 718 | tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1")) |
719 | 719 |
|
720 | 720 | from vllm.model_executor.layers.batch_invariant import ( |
721 | | - vllm_kernel_override_batch_invariant, |
| 721 | + vllm_is_batch_invariant, |
722 | 722 | ) |
723 | 723 |
|
724 | | - disable_custom_ar = vllm_kernel_override_batch_invariant() |
| 724 | + disable_custom_ar = vllm_is_batch_invariant() |
725 | 725 |
|
726 | 726 | if disable_custom_ar: |
727 | 727 | print(f"\n{'=' * 80}") |
|
0 commit comments