Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate --disable-flashinfer and --disable-flashinfer-sampling #2065

Merged
merged 2 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Deprecate --disable-flashinfer and --disable-flashinfer-sampling
  • Loading branch information
merrymercy committed Nov 18, 2024
commit 14e2aa81681b0f60fcba0ed07038b6aa6d638380
48 changes: 22 additions & 26 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,6 @@ class ServerArgs:
grammar_backend: Optional[str] = "outlines"

# Optimization/debug options
disable_flashinfer: bool = False
disable_flashinfer_sampling: bool = False
disable_radix_cache: bool = False
disable_jump_forward: bool = False
disable_cuda_graph: bool = False
Expand Down Expand Up @@ -179,20 +177,6 @@ def __post_init__(self):
self.chunked_prefill_size //= 4 # make it 2048
self.cuda_graph_max_bs = 4

# Deprecation warnings
if self.disable_flashinfer:
logger.warning(
"The option '--disable-flashinfer' will be deprecated in the next release. "
"Please use '--attention-backend triton' instead."
)
self.attention_backend = "triton"
if self.disable_flashinfer_sampling:
logger.warning(
"The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
"Please use '--sampling-backend pytorch' instead. "
)
self.sampling_backend = "pytorch"

if not is_flashinfer_available():
self.attention_backend = "triton"
self.sampling_backend = "pytorch"
Expand Down Expand Up @@ -615,16 +599,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
)

# Optimization/debug options
parser.add_argument(
"--disable-flashinfer",
action="store_true",
help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
)
parser.add_argument(
"--disable-flashinfer-sampling",
action="store_true",
help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
)
parser.add_argument(
"--disable-radix-cache",
action="store_true",
Expand Down Expand Up @@ -733,6 +707,18 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="Delete the model checkpoint after loading the model.",
)

# Deprecated arguments
parser.add_argument(
"--disable-flashinfer",
action=DeprecatedAction,
help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
)
parser.add_argument(
"--disable-flashinfer-sampling",
action=DeprecatedAction,
help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.",
)

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size
Expand Down Expand Up @@ -826,3 +812,13 @@ def __call__(self, parser, namespace, values, option_string=None):
getattr(namespace, self.dest)[name] = path
else:
getattr(namespace, self.dest)[lora_path] = lora_path


class DeprecatedAction(argparse.Action):
def __init__(self, option_strings, dest, nargs=0, **kwargs):
super(DeprecatedAction, self).__init__(
option_strings, dest, nargs=nargs, **kwargs
)

def __call__(self, parser, namespace, values, option_string=None):
raise ValueError(self.help)
2 changes: 2 additions & 0 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def is_flashinfer_available():
Check whether flashinfer is available.
As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
"""
if os.environ.get("SGLANG_IS_FLASHINFER_AVAILABLE", "true") == "false":
return False
return torch.cuda.is_available() and not is_hip()


Expand Down