Refactor oracle to separate support and selection #34658
mgoin wants to merge 1 commit into vllm-project:main from
Conversation
Signed-off-by: mgoin <mgoin64@gmail.com>
|
This pull request has merge conflicts that must be resolved before it can be merged. |
There was a problem hiding this comment.
Code Review
This pull request refactors the MoE backend selection logic in fp8.py and nvfp4.py to reduce code duplication and improve readability. This is achieved by introducing helper functions to check for backend support and select the best available option. The changes are a significant improvement to the codebase's maintainability. My review includes one suggestion for nvfp4.py to further improve consistency with the refactoring done in fp8.py.
| def _check_backend( | ||
| backend: NvFp4MoeBackend, | ||
| ) -> tuple[bool, str | None, type[mk.FusedMoEPermuteExpertsUnpermute] | None]: | ||
| """Check if a backend is supported. Returns (supported, reason, k_cls).""" | ||
| if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: | ||
| supported, reason = is_supported_config_trtllm( | ||
| config, weight_key, activation_key, activation_format | ||
| ) | ||
| return supported, reason, None | ||
| else: | ||
| return ( | ||
| f"NvFp4 MoE backend '{backend.value}' does not support the " | ||
| "deployment configuration." | ||
| k_cls = backend_to_kernel_cls(backend) | ||
| supported, reason = k_cls.is_supported_config( | ||
| k_cls, config, weight_key, activation_key, activation_format | ||
| ) | ||
|
|
||
| def _return_or_raise( | ||
| backend: NvFp4MoeBackend, | ||
| config: FusedMoEConfig, | ||
| weight_key: QuantKey | None, | ||
| activation_key: QuantKey | None, | ||
| activation_format: mk.FusedMoEActivationFormat, | ||
| ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute]]: | ||
| k_cls = backend_to_kernel_cls(backend) | ||
| supported, reason = k_cls.is_supported_config( | ||
| k_cls, config, weight_key, activation_key, activation_format | ||
| return supported, reason, k_cls if supported else None | ||
|
|
||
| def _filter_qualified( | ||
| candidates: list[NvFp4MoeBackend], | ||
| ) -> list[tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]]: | ||
| """Filter candidates to only those that pass is_supported_config.""" | ||
| qualified: list[ | ||
| tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None] | ||
| ] = [] | ||
| for backend in candidates: | ||
| supported, reason, k_cls = _check_backend(backend) | ||
| if supported: | ||
| qualified.append((backend, k_cls)) | ||
| else: | ||
| logger.debug_once( | ||
| f"NvFp4 MoE backend '{backend.value}' does not support the " | ||
| f"deployment configuration{f' since {reason}' if reason else ''}.", | ||
| scope="local", | ||
| ) | ||
| return qualified | ||
|
|
||
| def _select_and_log( | ||
| qualified: list[ | ||
| tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None] | ||
| ], | ||
| error_msg: str, | ||
| ) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]: | ||
| """Select the first qualified backend and log the selection.""" | ||
| if not qualified: | ||
| raise NotImplementedError(error_msg) | ||
|
|
||
| backend, k_cls = qualified[0] | ||
| qualified_names = [b.value for b, _ in qualified] | ||
| logger.info_once( | ||
| f"Using '{backend.value}' NvFp4 MoE backend out of " | ||
| f"potential backends: {qualified_names}.", | ||
| scope="local", | ||
| ) | ||
| if supported: | ||
| logger.info_once(_make_log_backend(backend)) | ||
| return backend, k_cls | ||
| raise ValueError(_make_log_unsupported(backend, reason)) | ||
| return backend, k_cls | ||
|
|
||
| # Handle environment variable overrides. | ||
| if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"): | ||
| if not envs.VLLM_USE_FLASHINFER_MOE_FP4: | ||
| # If the user rejects FlashInfer remove those backends. | ||
| # User explicitly disabled FlashInfer backends. | ||
| for b in FLASHINFER_NVFP4_MOE_BACKENDS: | ||
| AVAILABLE_BACKENDS.remove(b) | ||
|
|
||
| elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"): | ||
| # If user is explicit about backend, validate it. | ||
| # User explicitly requested a specific FlashInfer backend. | ||
| fi_backend = get_flashinfer_moe_backend() | ||
|
|
||
| if fi_backend == FlashinferMoeBackend.TENSORRT_LLM: | ||
| backend = NvFp4MoeBackend.FLASHINFER_TRTLLM | ||
| supported, reason = is_supported_config_trtllm( | ||
| config, weight_key, activation_key, activation_format | ||
| ) | ||
| if supported: | ||
| logger.info_once(_make_log_backend(backend)) | ||
| return backend, None | ||
| else: | ||
| raise ValueError(_make_log_unsupported(backend, reason)) | ||
| else: | ||
| backend = fi_2_vllm_backend_map[fi_backend] | ||
| return _return_or_raise( | ||
| backend, config, weight_key, activation_key, activation_format | ||
| backend = ( | ||
| NvFp4MoeBackend.FLASHINFER_TRTLLM | ||
| if fi_backend == FlashinferMoeBackend.TENSORRT_LLM | ||
| else fi_2_vllm_backend_map[fi_backend] | ||
| ) | ||
| supported, reason, k_cls = _check_backend(backend) | ||
| if supported: | ||
| logger.info_once( | ||
| f"Using '{backend.value}' NvFp4 MoE backend " | ||
| "(explicitly requested).", | ||
| scope="local", | ||
| ) | ||
| return backend, k_cls | ||
| raise ValueError( | ||
| f"NvFp4 MoE backend '{backend.value}' does not support the " | ||
| f"deployment configuration{f' since {reason}' if reason else ''}." | ||
| ) | ||
|
|
||
| else: | ||
| # If the user is not explicit about the backend, try each. | ||
| for backend in FLASHINFER_NVFP4_MOE_BACKENDS: | ||
| if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: | ||
| k_cls = None | ||
| supported, reason = is_supported_config_trtllm( | ||
| config, | ||
| weight_key, | ||
| activation_key, | ||
| activation_format, | ||
| ) | ||
| else: | ||
| k_cls = backend_to_kernel_cls(backend) | ||
| supported, reason = k_cls.is_supported_config( | ||
| k_cls, | ||
| config, | ||
| weight_key, | ||
| activation_key, | ||
| activation_format, | ||
| ) | ||
| if supported: | ||
| logger.info_once(_make_log_backend(backend), scope="local") | ||
| return backend, None | ||
| else: | ||
| logger.debug_once( | ||
| _make_log_unsupported(backend, reason), scope="local" | ||
| ) | ||
|
|
||
| raise NotImplementedError( | ||
| # User enabled FlashInfer but didn't specify which backend. | ||
| # Only consider FlashInfer backends. | ||
| qualified = _filter_qualified(FLASHINFER_NVFP4_MOE_BACKENDS) | ||
| return _select_and_log( | ||
| qualified, | ||
| "Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no " | ||
| "FlashInfer NVFP4 MoE backend supports the configuration." | ||
| "FlashInfer NVFP4 MoE backend supports the configuration.", | ||
| ) | ||
|
|
||
| if envs.VLLM_TEST_FORCE_FP8_MARLIN: | ||
| # Force Marlin backend for testing. | ||
| backend = NvFp4MoeBackend.MARLIN | ||
| return _return_or_raise( | ||
| backend, config, weight_key, activation_key, activation_format | ||
| ) | ||
|
|
||
| # Select kernels in order of backend. | ||
| for backend in AVAILABLE_BACKENDS: | ||
| if backend == NvFp4MoeBackend.FLASHINFER_TRTLLM: | ||
| k_cls = None # type: ignore[assignment] | ||
| supported, reason = is_supported_config_trtllm( | ||
| config, | ||
| weight_key, | ||
| activation_key, | ||
| activation_format, | ||
| ) | ||
| else: | ||
| k_cls = backend_to_kernel_cls(backend) | ||
| supported, reason = k_cls.is_supported_config( | ||
| k_cls, | ||
| config, | ||
| weight_key, | ||
| activation_key, | ||
| activation_format, | ||
| ) | ||
|
|
||
| supported, reason, k_cls = _check_backend(backend) | ||
| if supported: | ||
| logger.info_once(_make_log_backend(backend), scope="local") | ||
| logger.info_once( | ||
| f"Using '{backend.value}' NvFp4 MoE backend (forced via env var).", | ||
| scope="local", | ||
| ) | ||
| return backend, k_cls | ||
| else: | ||
| logger.debug_once(_make_log_unsupported(backend, reason), scope="local") | ||
| raise ValueError( | ||
| f"NvFp4 MoE backend '{backend.value}' does not support the " | ||
| f"deployment configuration{f' since {reason}' if reason else ''}." | ||
| ) |
There was a problem hiding this comment.
For consistency with the refactoring in fp8.py and to reduce code duplication, consider introducing a _check_and_return_explicit helper function. This would encapsulate the logic for handling explicitly requested backends, improving maintainability by ensuring that changes to this logic are made in one place.
def _check_backend(
    backend: NvFp4MoeBackend,
) -> tuple[bool, str | None, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
    """Probe whether *backend* supports the current deployment configuration.

    Returns a ``(supported, reason, kernel_cls)`` triple. ``reason`` explains
    a rejection (or is ``None``); ``kernel_cls`` is populated only for
    supported non-TRTLLM backends, since the TRTLLM path has no modular
    kernel class.
    """
    if backend != NvFp4MoeBackend.FLASHINFER_TRTLLM:
        kernel_cls = backend_to_kernel_cls(backend)
        ok, why = kernel_cls.is_supported_config(
            kernel_cls, config, weight_key, activation_key, activation_format
        )
        return ok, why, kernel_cls if ok else None
    # TRTLLM is special-cased: its support check is a free function and it
    # never yields a kernel class.
    ok, why = is_supported_config_trtllm(
        config, weight_key, activation_key, activation_format
    )
    return ok, why, None
def _filter_qualified(
    candidates: list[NvFp4MoeBackend],
) -> list[tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]]:
    """Keep only the candidate backends whose support check passes.

    Rejected candidates are logged at debug level together with the reason
    (when one was provided) so misconfigurations stay diagnosable.
    """
    passing: list[
        tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]
    ] = []
    for candidate in candidates:
        ok, why, kernel_cls = _check_backend(candidate)
        if not ok:
            logger.debug_once(
                f"NvFp4 MoE backend '{candidate.value}' does not support the "
                f"deployment configuration{f' since {why}' if why else ''}.",
                scope="local",
            )
            continue
        passing.append((candidate, kernel_cls))
    return passing
def _select_and_log(
    qualified: list[
        tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]
    ],
    error_msg: str,
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
    """Pick the highest-priority qualified backend and log the choice.

    Raises:
        NotImplementedError: with *error_msg* when no backend qualified.
    """
    if not qualified:
        raise NotImplementedError(error_msg)
    # Candidates are already in priority order, so the first entry wins.
    chosen, kernel_cls = qualified[0]
    names = [candidate.value for candidate, _ in qualified]
    logger.info_once(
        f"Using '{chosen.value}' NvFp4 MoE backend out of "
        f"potential backends: {names}.",
        scope="local",
    )
    return chosen, kernel_cls
def _check_and_return_explicit(
    backend: NvFp4MoeBackend,
    selection_reason: str,
) -> tuple[NvFp4MoeBackend, type[mk.FusedMoEPermuteExpertsUnpermute] | None]:
    """Validate a user-forced backend; return it or raise ``ValueError``.

    *selection_reason* is interpolated into the log line to record why this
    backend bypassed the automatic selection.
    """
    ok, why, kernel_cls = _check_backend(backend)
    if not ok:
        raise ValueError(
            f"NvFp4 MoE backend '{backend.value}' does not support the "
            f"deployment configuration{f' since {why}' if why else ''}."
        )
    logger.info_once(
        f"Using '{backend.value}' NvFp4 MoE backend ({selection_reason}).",
        scope="local",
    )
    return backend, kernel_cls
# Handle environment variable overrides.
if envs.is_set("VLLM_USE_FLASHINFER_MOE_FP4"):
if not envs.VLLM_USE_FLASHINFER_MOE_FP4:
# User explicitly disabled FlashInfer backends.
for b in FLASHINFER_NVFP4_MOE_BACKENDS:
AVAILABLE_BACKENDS.remove(b)
elif envs.is_set("VLLM_FLASHINFER_MOE_BACKEND"):
# User explicitly requested a specific FlashInfer backend.
fi_backend = get_flashinfer_moe_backend()
backend = (
NvFp4MoeBackend.FLASHINFER_TRTLLM
if fi_backend == FlashinferMoeBackend.TENSORRT_LLM
else fi_2_vllm_backend_map[fi_backend]
)
return _check_and_return_explicit(backend, "explicitly requested")
else:
# User enabled FlashInfer but didn't specify which backend.
# Only consider FlashInfer backends.
qualified = _filter_qualified(FLASHINFER_NVFP4_MOE_BACKENDS)
return _select_and_log(
qualified,
"Found VLLM_USE_FLASHINFER_MOE_FP4=1, but no "
"FlashInfer NVFP4 MoE backend supports the configuration.",
)
if envs.VLLM_TEST_FORCE_FP8_MARLIN:
# Force Marlin backend for testing.
return _check_and_return_explicit(NvFp4MoeBackend.MARLIN,
"forced via env var")
Purpose
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
Update `supported_models.md` and `examples` for a new model.