
[Quality] Add CI for formatting #343


Merged
merged 10 commits on Jul 3, 2023
31 changes: 31 additions & 0 deletions .github/workflows/pylint.yml
@@ -0,0 +1,31 @@
name: pylint

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  pylint:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install pylint==2.8.2
    - name: Analysing the code with pylint
      run: |
        pylint vllm
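
The job pins pylint 2.8.2 and lints the `vllm` package on Python 3.10. A minimal sketch of reproducing the check locally, assuming pylint 2.8.2 is already installed (the wrapper script itself is illustrative, not part of this PR):

```python
# Run the same lint the CI job runs and propagate pylint's exit code.
import subprocess
import sys

result = subprocess.run([sys.executable, "-m", "pylint", "vllm"])
sys.exit(result.returncode)
```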
31 changes: 31 additions & 0 deletions .github/workflows/yapf.yml
@@ -0,0 +1,31 @@
name: yapf

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
jobs:
  yapf:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install yapf==0.32.0
        pip install toml==0.10.2
    - name: Running yapf
      run: |
        yapf --diff --recursive vllm --exclude 'vllm/model_executor/parallel_utils/**'
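
Unlike `yapf -i`, the `--diff` mode only prints the changes it would make and exits non-zero when any file needs reformatting, which is what fails the check. A rough local equivalent, assuming the pinned yapf 0.32.0 and toml 0.10.2 are installed and `yapf` is on the PATH (this wrapper is illustrative, not from the PR):

```python
# Mirror the CI invocation; a non-zero exit code means formatting drift.
import subprocess
import sys

result = subprocess.run([
    "yapf", "--diff", "--recursive", "vllm",
    "--exclude", "vllm/model_executor/parallel_utils/**",
])
sys.exit(result.returncode)
```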
8 changes: 8 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -2,6 +2,7 @@
 import time
 from typing import Dict, List, Optional
 
+from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.engine.ray_utils import initialize_cluster, ray
@@ -206,6 +207,13 @@ async def abort(self, request_id: str) -> None:
             self.is_engine_running = False
             self.kicking_request_id = None
 
+    async def get_model_config(self) -> ModelConfig:
+        """Get the model configuration of the vLLM engine."""
+        if self.engine_use_ray:
+            return await self.engine.get_model_config.remote()
+        else:
+            return self.engine.get_model_config()
+
     @classmethod
     def from_engine_args(cls,
                          engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
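
The new `get_model_config` coroutine works with and without Ray, and the API server below calls it once at startup via `asyncio.run`. A hypothetical caller-side sketch (the model name is a placeholder, not something this PR specifies):

```python
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

# Build an async engine, then fetch its ModelConfig once before serving.
engine_args = AsyncEngineArgs(model="facebook/opt-125m")  # placeholder model
engine = AsyncLLMEngine.from_engine_args(engine_args)
model_config = asyncio.run(engine.get_model_config())
print(model_config.hf_config)
```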
4 changes: 4 additions & 0 deletions vllm/engine/llm_engine.py
@@ -210,6 +210,10 @@ def abort_request(self, request_id: str) -> None:
         """
         self.scheduler.abort_seq_group(request_id)
 
+    def get_model_config(self) -> ModelConfig:
+        """Gets the model configuration."""
+        return self.model_config
+
     def get_num_unfinished_requests(self) -> int:
         """Gets the number of unfinished requests."""
         return self.scheduler.get_num_unfinished_seq_groups()
67 changes: 37 additions & 30 deletions vllm/entrypoints/openai/api_server.py
@@ -2,28 +2,30 @@
 # https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py
 
 import argparse
+import asyncio
 from http import HTTPStatus
 import json
 import time
-from typing import AsyncGenerator, Dict, List, Optional, Union, Any
+from typing import AsyncGenerator, Dict, List, Optional
 
 import fastapi
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
+from fastchat.conversation import (Conversation, SeparatorStyle,
+                                   get_conv_template)
 import uvicorn
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
     CompletionRequest, CompletionResponse, CompletionResponseChoice,
     CompletionResponseStreamChoice, CompletionStreamResponse,
-    ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
-    ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
-    ChatMessage, DeltaMessage, ErrorResponse, LogProbs,
-    ModelCard, ModelList, ModelPermission, UsageInfo)
-from fastchat.conversation import Conversation, SeparatorStyle, get_conv_template
+    ChatCompletionRequest, ChatCompletionResponse,
+    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
+    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -95,15 +97,15 @@ async def get_gen_prompt(request) -> str:
     return prompt
 
 
-async def check_length(request, prompt, engine):
-    if hasattr(engine.engine.model_config.hf_config, "max_sequence_length"):
-        context_len = engine.engine.model_config.hf_config.max_sequence_length
-    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
-        context_len = engine.engine.model_config.hf_config.seq_length
-    elif hasattr(engine.engine.model_config.hf_config, "max_position_embeddings"):
-        context_len = engine.engine.model_config.hf_config.max_position_embeddings
-    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
-        context_len = engine.engine.model_config.hf_config.seq_length
+async def check_length(request, prompt, model_config):
+    if hasattr(model_config.hf_config, "max_sequence_length"):
+        context_len = model_config.hf_config.max_sequence_length
+    elif hasattr(model_config.hf_config, "seq_length"):
+        context_len = model_config.hf_config.seq_length
+    elif hasattr(model_config.hf_config, "max_position_embeddings"):
+        context_len = model_config.hf_config.max_position_embeddings
+    elif hasattr(model_config.hf_config, "seq_length"):
+        context_len = model_config.hf_config.seq_length
     else:
         context_len = 2048
 
Expand Down Expand Up @@ -182,7 +184,7 @@ async def create_chat_completion(raw_request: Request):
"logit_bias is not currently supported")

prompt = await get_gen_prompt(request)
error_check_ret = await check_length(request, prompt, engine)
error_check_ret = await check_length(request, prompt, engine_model_config)
if error_check_ret is not None:
return error_check_ret

Expand All @@ -206,15 +208,16 @@ async def create_chat_completion(raw_request: Request):
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

result_generator = engine.generate(prompt, sampling_params,
request_id)
result_generator = engine.generate(prompt, sampling_params, request_id)

async def abort_request() -> None:
await engine.abort(request_id)

def create_stream_response_json(index: int,
text: str,
finish_reason: Optional[str] = None) -> str:
def create_stream_response_json(
index: int,
text: str,
finish_reason: Optional[str] = None,
) -> str:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=text),
@@ -238,10 +241,11 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
                 delta=DeltaMessage(role="assistant"),
                 finish_reason=None,
             )
-            chunk = ChatCompletionStreamResponse(
-                id=request_id, choices=[choice_data], model=model_name
-            )
-            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+            chunk = ChatCompletionStreamResponse(id=request_id,
+                                                 choices=[choice_data],
+                                                 model=model_name)
+            data = chunk.json(exclude_unset=True, ensure_ascii=False)
+            yield f"data: {data}\n\n"
 
         previous_texts = [""] * request.n
         previous_num_tokens = [0] * request.n
@@ -295,8 +299,8 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
         choices.append(choice_data)
 
     num_prompt_tokens = len(final_res.prompt_token_ids)
-    num_generated_tokens = sum(len(output.token_ids)
-                               for output in final_res.outputs)
+    num_generated_tokens = sum(
+        len(output.token_ids) for output in final_res.outputs)
     usage = UsageInfo(
         prompt_tokens=num_prompt_tokens,
         completion_tokens=num_generated_tokens,
@@ -314,9 +318,11 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
         # When user requests streaming but we don't stream, we still need to
         # return a streaming response with a single event.
         response_json = response.json(ensure_ascii=False)
+
         async def fake_stream_generator() -> AsyncGenerator[str, None]:
             yield f"data: {response_json}\n\n"
             yield "data: [DONE]\n\n"
+
         return StreamingResponse(fake_stream_generator(),
                                  media_type="text/event-stream")
 
@@ -367,9 +373,9 @@ async def create_completion(raw_request: Request):
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "please provide at least one prompt")
         if len(request.prompt) > 1:
-            return create_error_response(HTTPStatus.BAD_REQUEST,
-                                         "multiple prompts in a batch is not "
-                                         "currently supported")
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                "multiple prompts in a batch is not currently supported")
         prompt = request.prompt[0]
     else:
         prompt = request.prompt
@@ -571,6 +577,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
+    engine_model_config = asyncio.run(engine.get_model_config())
 
     # A separate tokenizer to map token IDs to strings.
     tokenizer = get_tokenizer(engine_args.tokenizer,
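
With the model config fetched once at startup, `check_length` no longer reaches into the engine internals. Its context-length fallback can be written as a standalone helper; this is an illustrative rewrite, not code from the PR, and it also collapses the duplicated `seq_length` branch:

```python
# Same attribute order as check_length above, against a prefetched ModelConfig.
def get_context_len(model_config, default: int = 2048) -> int:
    hf_config = model_config.hf_config
    for attr in ("max_sequence_length", "seq_length", "max_position_embeddings"):
        if hasattr(hf_config, attr):
            return getattr(hf_config, attr)
    return default
```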
3 changes: 2 additions & 1 deletion vllm/model_executor/models/bloom.py
@@ -1,5 +1,6 @@
 # coding=utf-8
-# Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
 # Copyright 2023 The CacheFlow team.
 # Copyright 2022 HuggingFace Inc. team and BigScience workshop.
 #