From 11f881d173c4744a3ebf31736c264a0b0af4396f Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Sun, 17 Nov 2024 16:20:58 -0800
Subject: [PATCH] Deprecate --disable-flashinfer and --disable-flashinfer-sampling (#2065)

---
 python/sglang/srt/server_args.py   | 48 ++++++++++++++----------------
 python/sglang/srt/utils.py         |  2 ++
 test/srt/test_torch_compile_moe.py |  3 +-
 3 files changed, 25 insertions(+), 28 deletions(-)

diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 26c339e649..2a4b0d67ef 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -116,8 +116,6 @@ class ServerArgs:
     grammar_backend: Optional[str] = "outlines"
 
     # Optimization/debug options
-    disable_flashinfer: bool = False
-    disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
     disable_jump_forward: bool = False
     disable_cuda_graph: bool = False
@@ -179,20 +177,6 @@ def __post_init__(self):
             self.chunked_prefill_size //= 4  # make it 2048
             self.cuda_graph_max_bs = 4
 
-        # Deprecation warnings
-        if self.disable_flashinfer:
-            logger.warning(
-                "The option '--disable-flashinfer' will be deprecated in the next release. "
-                "Please use '--attention-backend triton' instead."
-            )
-            self.attention_backend = "triton"
-        if self.disable_flashinfer_sampling:
-            logger.warning(
-                "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
-                "Please use '--sampling-backend pytorch' instead. "
-            )
-            self.sampling_backend = "pytorch"
-
         if not is_flashinfer_available():
             self.attention_backend = "triton"
             self.sampling_backend = "pytorch"
@@ -615,16 +599,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
         )
 
         # Optimization/debug options
-        parser.add_argument(
-            "--disable-flashinfer",
-            action="store_true",
-            help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
-        )
-        parser.add_argument(
-            "--disable-flashinfer-sampling",
-            action="store_true",
-            help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
-        )
         parser.add_argument(
             "--disable-radix-cache",
             action="store_true",
@@ -733,6 +707,18 @@ def add_cli_args(parser: argparse.ArgumentParser):
             help="Delete the model checkpoint after loading the model.",
         )
 
+        # Deprecated arguments
+        parser.add_argument(
+            "--disable-flashinfer",
+            action=DeprecatedAction,
+            help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
+        )
+        parser.add_argument(
+            "--disable-flashinfer-sampling",
+            action=DeprecatedAction,
+            help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytorch' instead.",
+        )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
@@ -826,3 +812,13 @@ def __call__(self, parser, namespace, values, option_string=None):
                 getattr(namespace, self.dest)[name] = path
             else:
                 getattr(namespace, self.dest)[lora_path] = lora_path
+
+
+class DeprecatedAction(argparse.Action):
+    def __init__(self, option_strings, dest, nargs=0, **kwargs):
+        super(DeprecatedAction, self).__init__(
+            option_strings, dest, nargs=nargs, **kwargs
+        )
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        raise ValueError(self.help)
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 7e6174ad87..8df6d7b7e7 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -71,6 +71,8 @@ def is_flashinfer_available():
     Check whether flashinfer is available.
     As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
     """
+    if os.environ.get("SGLANG_IS_FLASHINFER_AVAILABLE", "true") == "false":
+        return False
     return torch.cuda.is_available() and not is_hip()
 
 
diff --git a/test/srt/test_torch_compile_moe.py b/test/srt/test_torch_compile_moe.py
index 934ef34994..d19ab2bbda 100644
--- a/test/srt/test_torch_compile_moe.py
+++ b/test/srt/test_torch_compile_moe.py
@@ -65,8 +65,7 @@ def test_throughput(self):
         tok = time.time()
         print(f"{res=}")
         throughput = max_tokens / (tok - tic)
-        print(f"Throughput: {throughput} tokens/s")
-        self.assertGreaterEqual(throughput, 290)
+        self.assertGreaterEqual(throughput, 285)
 
 
 if __name__ == "__main__":
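
Note: the DeprecatedAction class added above turns the old flags into hard errors rather than silent remaps; argparse raises as soon as it sees the flag, and the help text doubles as the migration message. The following is a minimal standalone sketch of that behavior (the parser and single flag below are illustrative, not the full server argument set):

import argparse


class DeprecatedAction(argparse.Action):
    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super(DeprecatedAction, self).__init__(
            option_strings, dest, nargs=nargs, **kwargs
        )

    def __call__(self, parser, namespace, values, option_string=None):
        # Reject the flag outright; the help text carries the migration hint.
        raise ValueError(self.help)


parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-flashinfer",
    action=DeprecatedAction,
    help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
)

parser.parse_args([])                        # ok: flag not given
parser.parse_args(["--disable-flashinfer"])  # raises ValueError with the hint above

The utils.py change provides a matching escape hatch: setting SGLANG_IS_FLASHINFER_AVAILABLE=false makes is_flashinfer_available() return False, so __post_init__ falls back to the Triton attention backend and the PyTorch sampling backend even on NVIDIA GPUs.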