Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] Add FlexibleArgumentParser to support both underscore and dash in names #5718

Merged
merged 2 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/benchmark_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.arg_utils import EngineArgs, FlexibleArgumentParser
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

Expand Down Expand Up @@ -120,7 +120,7 @@ def run_to_completion(profile_dir: Optional[str] = None):


if __name__ == '__main__':
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/benchmark_prefix_caching.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import time

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import FlexibleArgumentParser

PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501

Expand Down Expand Up @@ -44,7 +44,7 @@ def main(args):


if __name__ == "__main__":
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description='Benchmark the performance with or without automatic '
'prefix caching.')
parser.add_argument('--model',
Expand Down
7 changes: 6 additions & 1 deletion benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@
except ImportError:
from backend_request_func import get_tokenizer

try:
from vllm.engine.arg_utils import FlexibleArgumentParser
except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser


@dataclass
class BenchmarkMetrics:
Expand Down Expand Up @@ -511,7 +516,7 @@ def main(args: argparse.Namespace):


if __name__ == "__main__":
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="Benchmark the online serving throughput.")
parser.add_argument(
"--backend",
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/benchmark_throughput.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.arg_utils import EngineArgs, FlexibleArgumentParser
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS


Expand Down Expand Up @@ -261,7 +261,7 @@ def main(args: argparse.Namespace):


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend",
type=str,
choices=["vllm", "hf", "mii"],
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.engine.arg_utils import FlexibleArgumentParser

DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
Expand Down Expand Up @@ -293,7 +294,7 @@ def to_torch_dtype(dt):
return torch.float8_e4m3fn
raise ValueError("unsupported dtype")

parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="""
Benchmark Cutlass GEMM.

Expand Down
4 changes: 2 additions & 2 deletions benchmarks/kernels/benchmark_aqlm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
import os
import sys
from typing import Optional
Expand All @@ -7,6 +6,7 @@
import torch.nn.functional as F

from vllm import _custom_ops as ops
from vllm.engine.arg_utils import FlexibleArgumentParser
from vllm.model_executor.layers.quantization.aqlm import (
dequantize_weight, generic_dequantize_gemm, get_int_dtype,
optimized_dequantize_gemm)
Expand Down Expand Up @@ -137,7 +137,7 @@ def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None:

def main():

parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")
parser = FlexibleArgumentParser(description="Benchmark aqlm performance.")

# Add arguments
parser.add_argument("--nbooks",
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/kernels/benchmark_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig

from vllm.engine.arg_utils import FlexibleArgumentParser
from vllm.model_executor.layers.fused_moe.fused_moe import *


Expand Down Expand Up @@ -315,7 +316,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = FlexibleArgumentParser()
parser.add_argument("--model",
type=str,
default="mistralai/Mixtral-8x7B-Instruct-v0.1")
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/kernels/benchmark_paged_attention.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import argparse
import random
import time
from typing import List, Optional

import torch

from vllm import _custom_ops as ops
from vllm.engine.arg_utils import FlexibleArgumentParser
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random

NUM_BLOCKS = 1024
Expand Down Expand Up @@ -161,7 +161,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:


if __name__ == '__main__':
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="Benchmark the paged attention kernel.")
parser.add_argument("--version",
type=str,
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/kernels/benchmark_rope.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import argparse
from itertools import accumulate
from typing import List, Optional

import nvtx
import torch

from vllm.engine.arg_utils import FlexibleArgumentParser
from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding,
get_rope)

Expand Down Expand Up @@ -86,7 +86,7 @@ def benchmark_rope_kernels_multi_lora(


if __name__ == '__main__':
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="Benchmark the rotary embedding kernels.")
parser.add_argument("--is-neox-style", type=bool, default=True)
parser.add_argument("--batch-size", type=int, default=16)
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/overheads/benchmark_hashing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import argparse
import cProfile
import pstats

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import FlexibleArgumentParser

# A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?"
Expand Down Expand Up @@ -47,7 +47,7 @@ def main(args):


if __name__ == "__main__":
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description='Benchmark the performance of hashing function in'
'automatic prefix caching.')
parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k')
Expand Down
5 changes: 2 additions & 3 deletions examples/aqlm_example.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import argparse

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import FlexibleArgumentParser


def main():

parser = argparse.ArgumentParser(description='AQLM examples')
parser = FlexibleArgumentParser(description='AQLM examples')

parser.add_argument('--model',
'-m',
Expand Down
3 changes: 2 additions & 1 deletion examples/llm_engine_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List, Tuple

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.engine.arg_utils import FlexibleArgumentParser


def create_test_prompts() -> List[Tuple[str, SamplingParams]]:
Expand Down Expand Up @@ -55,7 +56,7 @@ def main(args: argparse.Namespace):


if __name__ == '__main__':
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description='Demo on using the LLMEngine class directly')
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
Expand Down
4 changes: 2 additions & 2 deletions examples/save_sharded_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
tensor_parallel_size=8,
)
"""
import argparse
import dataclasses
import os
import shutil
from pathlib import Path

from vllm import LLM, EngineArgs
from vllm.engine.arg_utils import FlexibleArgumentParser

parser = argparse.ArgumentParser()
parser = FlexibleArgumentParser()
EngineArgs.add_cli_args(parser)
parser.add_argument("--output",
"-o",
Expand Down
4 changes: 2 additions & 2 deletions examples/tensorize_vllm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import uuid

from vllm import LLM
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.arg_utils import EngineArgs, FlexibleArgumentParser
from vllm.model_executor.model_loader.tensorizer import (TensorizerArgs,
TensorizerConfig,
tensorize_vllm_model)
Expand Down Expand Up @@ -96,7 +96,7 @@


def parse_args():
parser = argparse.ArgumentParser(
parser = FlexibleArgumentParser(
description="An example script that can be used to serialize and "
"deserialize vLLM models. These models "
"can be loaded using tensorizer directly to the GPU "
Expand Down
5 changes: 2 additions & 3 deletions tests/async_engine/api_server_async_engine.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""vllm.entrypoints.api_server with some extra logging for testing."""
import argparse
from typing import Any, Dict

import uvicorn
from fastapi.responses import JSONResponse, Response

import vllm.entrypoints.api_server
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.arg_utils import AsyncEngineArgs, FlexibleArgumentParser
from vllm.engine.async_llm_engine import AsyncLLMEngine

app = vllm.entrypoints.api_server.app
Expand All @@ -33,7 +32,7 @@ def stats() -> Response:


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser = AsyncEngineArgs.add_cli_args(parser)
Expand Down
34 changes: 26 additions & 8 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import dataclasses
import json
import sys
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
Expand All @@ -20,6 +21,24 @@ def nullable_str(val: str):
return val


class FlexibleArgumentParser(argparse.ArgumentParser):
    """ArgumentParser that accepts both underscores and dashes in long
    option names.

    Any long option on the command line is normalized so that
    ``--model_name`` and ``--model-name`` both match an argument that was
    registered as ``--model-name``. Positional arguments and short options
    are passed through unchanged.
    """

    def parse_args(self, args=None, namespace=None):
        if args is None:
            args = sys.argv[1:]

        # Normalize underscores to dashes in long option NAMES only.
        # For '--key=value' tokens, split on the first '=' so the value
        # part is preserved byte-for-byte (e.g. '--model_name=my_model'
        # must become '--model-name=my_model', not '--model-name=my-model').
        processed_args = []
        for arg in args:
            if arg.startswith('--'):
                if '=' in arg:
                    key, value = arg.split('=', 1)
                    key = '--' + key[len('--'):].replace('_', '-')
                    processed_args.append(f'{key}={value}')
                else:
                    processed_args.append('--' + arg[2:].replace('_', '-'))
            else:
                processed_args.append(arg)

        return super().parse_args(processed_args, namespace)


mgoin marked this conversation as resolved.
Show resolved Hide resolved
@dataclass
class EngineArgs:
"""Arguments for vLLM engine."""
Expand Down Expand Up @@ -110,7 +129,7 @@ def __post_init__(self):

@staticmethod
def add_cli_args_for_vlm(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument('--image-input-type',
type=nullable_str,
default=None,
Expand Down Expand Up @@ -156,8 +175,7 @@ def add_cli_args_for_vlm(
return parser

@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Shared CLI arguments for vLLM engine."""

# Model arguments
Expand Down Expand Up @@ -800,8 +818,8 @@ class AsyncEngineArgs(EngineArgs):
max_log_len: Optional[int] = None

@staticmethod
def add_cli_args(parser: argparse.ArgumentParser,
async_args_only: bool = False) -> argparse.ArgumentParser:
def add_cli_args(parser: FlexibleArgumentParser,
async_args_only: bool = False) -> FlexibleArgumentParser:
if not async_args_only:
parser = EngineArgs.add_cli_args(parser)
parser.add_argument('--engine-use-ray',
Expand All @@ -822,13 +840,13 @@ def add_cli_args(parser: argparse.ArgumentParser,

# These functions are used by sphinx to build the documentation
def _engine_args_parser():
return EngineArgs.add_cli_args(argparse.ArgumentParser())
return EngineArgs.add_cli_args(FlexibleArgumentParser())


def _async_engine_args_parser():
return AsyncEngineArgs.add_cli_args(argparse.ArgumentParser(),
return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(),
async_args_only=True)


def _vlm_engine_args_parser():
return EngineArgs.add_cli_args_for_vlm(argparse.ArgumentParser())
return EngineArgs.add_cli_args_for_vlm(FlexibleArgumentParser())
5 changes: 2 additions & 3 deletions vllm/entrypoints/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
change `vllm/entrypoints/openai/api_server.py` instead.
"""

import argparse
import json
import ssl
from typing import AsyncGenerator
Expand All @@ -15,7 +14,7 @@
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.arg_utils import AsyncEngineArgs, FlexibleArgumentParser
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
Expand Down Expand Up @@ -80,7 +79,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = FlexibleArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None)
Expand Down
Loading
Loading