[Quality] Add code formatter and linter (vllm-project#326)
zhuohan123 authored Jul 3, 2023
1 parent 0ffded8 commit d6fa1be
Showing 47 changed files with 1,549 additions and 619 deletions.
434 changes: 434 additions & 0 deletions .pylintrc

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion CONTRIBUTING.md
@@ -49,12 +49,15 @@ If not, please file a new issue, providing as much relevant information as possible.

In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).

We include a formatting script [`format.sh`](./format.sh) to format the code.

### Pull Requests

When submitting a pull request:

1. Make sure your code has been rebased on top of the latest commit on the main branch.
2. Include a detailed description of the changes in the pull request.
2. Ensure code is properly formatted by running [`format.sh`](./format.sh).
3. Include a detailed description of the changes in the pull request.
Explain why you made the changes you did.
If your pull request fixes an open issue, please include a reference to it in the description.

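The CONTRIBUTING.md change above asks contributors to run the new [`format.sh`](./format.sh) script before submitting a pull request. A minimal sketch of that pre-PR step, assuming the pinned tools from requirements-dev.txt (updated later in this commit) are installed with pip:

    # Install the formatter/linter versions that format.sh expects.
    pip install -r requirements-dev.txt

    # Format files that differ from origin/main, then run pylint (the script's default mode).
    bash format.sh
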
7 changes: 5 additions & 2 deletions examples/api_client.py
@@ -14,7 +14,9 @@ def clear_line(n: int = 1) -> None:
print(LINE_UP, end=LINE_CLEAR, flush=True)


def post_http_request(prompt: str, api_url: str, n: int = 1,
def post_http_request(prompt: str,
api_url: str,
n: int = 1,
stream: bool = False) -> requests.Response:
headers = {"User-Agent": "Test Client"}
pload = {
@@ -30,7 +32,8 @@ def post_http_request(prompt: str, api_url: str, n: int = 1,


def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
25 changes: 16 additions & 9 deletions examples/gradio_webserver.py
@@ -12,9 +12,14 @@ def http_bot(prompt):
"stream": True,
"max_tokens": 128,
}
response = requests.post(args.model_url, headers=headers, json=pload, stream=True)

for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
response = requests.post(args.model_url,
headers=headers,
json=pload,
stream=True)

for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"][0]
@@ -23,11 +28,11 @@

def build_demo():
with gr.Blocks() as demo:
gr.Markdown(
"# vLLM text completion demo\n"
)
inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model")
gr.Markdown("# vLLM text completion demo\n")
inputbox = gr.Textbox(label="Input",
placeholder="Enter text and press ENTER")
outputbox = gr.Textbox(label="Output",
placeholder="Generated result from the model")
inputbox.submit(http_bot, [inputbox], [outputbox])
return demo

@@ -36,7 +41,9 @@ def build_demo():
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8001)
parser.add_argument("--model-url", type=str, default="http://localhost:8000/generate")
parser.add_argument("--model-url",
type=str,
default="http://localhost:8000/generate")
args = parser.parse_args()

demo = build_demo()
9 changes: 7 additions & 2 deletions examples/llm_engine_example.py
@@ -14,9 +14,14 @@ def main(args: argparse.Namespace):
("To be or not to be,",
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
("What is the meaning of life?",
SamplingParams(n=2, best_of=5, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
SamplingParams(n=2,
best_of=5,
temperature=0.8,
top_p=0.95,
frequency_penalty=0.1)),
("It is only with the heart that one can see rightly",
SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
SamplingParams(n=3, best_of=3, use_beam_search=True,
temperature=0.0)),
]

# Run the engine by calling `engine.step()` manually.
1 change: 0 additions & 1 deletion examples/offline_inference.py
@@ -1,6 +1,5 @@
from vllm import LLM, SamplingParams


# Sample prompts.
prompts = [
"Hello, my name is",
9 changes: 7 additions & 2 deletions examples/openai_client.py
@@ -12,8 +12,13 @@
# Test completion API
stream = True
completion = openai.Completion.create(
model=model, prompt="A robot may not injure a human being", echo=False, n=2,
best_of=3, stream=stream, logprobs=3)
model=model,
prompt="A robot may not injure a human being",
echo=False,
n=2,
best_of=3,
stream=stream,
logprobs=3)

# print the completion
if stream:
108 changes: 108 additions & 0 deletions format.sh
@@ -0,0 +1,108 @@
#!/usr/bin/env bash
# YAPF formatter, adapted from ray and skypilot.
#
# Usage:
#    # Do work and commit your work.

#    # Format files that differ from origin/main.
#    bash format.sh

#    # Commit changed files with message 'Run yapf and pylint'
#
#
# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
# You are encouraged to run this locally before pushing changes for review.

# Cause the script to exit if a single command fails
set -eo pipefail

# this stops git rev-parse from failing if we run this from the .git directory
builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
ROOT="$(git rev-parse --show-toplevel)"
builtin cd "$ROOT" || exit 1

YAPF_VERSION=$(yapf --version | awk '{print $2}')
PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}')
MYPY_VERSION=$(mypy --version | awk '{print $2}')

# # params: tool name, tool version, required version
tool_version_check() {
    if [[ $2 != $3 ]]; then
        echo "Wrong $1 version installed: $3 is required, not $2."
        exit 1
    fi
}

tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "pylint" $PYLINT_VERSION "$(grep "pylint==" requirements-dev.txt | cut -d'=' -f3)"
tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"

YAPF_FLAGS=(
    '--recursive'
    '--parallel'
)

YAPF_EXCLUDES=(
    '--exclude' 'build/**'
    '--exclude' 'vllm/model_executor/parallel_utils/**'
)

# Format specified files
format() {
    yapf --in-place "${YAPF_FLAGS[@]}" "$@"
}

# Format files that differ from main branch. Ignores dirs that are not slated
# for autoformat yet.
format_changed() {
    # The `if` guard ensures that the list of filenames is not empty, which
    # could cause yapf to receive 0 positional arguments, making it hang
    # waiting for STDIN.
    #
    # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
    # exist on both branches.
    MERGEBASE="$(git merge-base origin/main HEAD)"

    if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
        git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \
            yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
    fi

}

# Format all files
format_all() {
    yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm
}

## This flag formats individual files. --files *must* be the first command line
## arg to use this option.
if [[ "$1" == '--files' ]]; then
format "${@:2}"
# If `--all` is passed, then any further arguments are ignored and the
# entire python directory is formatted.
elif [[ "$1" == '--all' ]]; then
format_all
else
# Format only the files that changed in last commit.
format_changed
fi
echo 'vLLM yapf: Done'

# Run mypy
# TODO(zhuohan): Enable mypy
# echo 'vLLM mypy:'
# mypy

# Run Pylint
echo 'vLLM Pylint:'
pylint vllm

if ! git diff --quiet &>/dev/null; then
    echo 'Reformatted files. Please review and stage the changes.'
    echo 'Changes not staged for commit:'
    echo
    git --no-pager diff --name-only

    exit 1
fi
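Based on the flag handling above, format.sh supports three invocation modes. A usage sketch (the file path is only an example):

    # Default: format files changed relative to the merge-base with origin/main, then run pylint.
    bash format.sh

    # Format specific files; --files must be the first argument.
    bash format.sh --files vllm/engine/llm_engine.py

    # Format the entire vllm directory.
    bash format.sh --all
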
12 changes: 11 additions & 1 deletion requirements-dev.txt
@@ -1,2 +1,12 @@
mypy
# formatting
yapf==0.32.0
pylint==2.8.2

# type checking
mypy==0.991
types-PyYAML
types-requests
types-setuptools

# testing
pytest
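format.sh checks the locally installed tool versions against these pins before formatting. A quick manual check, assuming the tools are already on your PATH:

    yapf --version                 # expected: yapf 0.32.0
    pylint --version | head -n 1   # expected: pylint 2.8.2
    mypy --version                 # expected: mypy 0.991
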
53 changes: 33 additions & 20 deletions tests/kernels/test_attention.py
@@ -60,7 +60,7 @@ def ref_single_query_cached_kv_attention(
keys = torch.stack(keys, dim=0)
values = torch.stack(values, dim=0)

scale = 1.0 / (head_size ** 0.5)
scale = 1.0 / (head_size**0.5)
out = ref_masked_attention(q, keys, values, scale)
out = out.view(num_heads, head_size)
output[i].copy_(out, non_blocking=True)
@@ -74,7 +74,7 @@ def ref_multi_query_kv_attention(
dtype: torch.dtype,
) -> torch.Tensor:
head_size = query.shape[-1]
scale = 1.0 / (head_size ** 0.5)
scale = 1.0 / (head_size**0.5)

num_seqs = len(cu_seq_lens) - 1
ref_outputs = []
@@ -84,8 +84,8 @@
seq_len = end_idx - start_idx

# Create attention mask.
attn_mask = torch.triu(
torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
diagonal=1)
attn_mask = attn_mask * torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype, device='cuda')

@@ -113,7 +113,7 @@ def ref_multi_query_cached_kv_attention(
num_heads = value_cache.shape[1]
head_size = value_cache.shape[2]
block_size = value_cache.shape[3]
scale = 1.0 / (head_size ** 0.5)
scale = 1.0 / (head_size**0.5)

num_queries = len(cu_query_lens) - 1
ref_outputs = []
@@ -125,8 +125,8 @@
block_table = block_tables[i]

# Create attention mask
attn_mask = torch.triu(
torch.ones(query_len, context_len), diagonal=context_len - query_len + 1) * -1e5
attn_mask = torch.triu(torch.ones(query_len, context_len),
diagonal=context_len - query_len + 1) * -1e5
attn_mask = attn_mask.to(dtype=dtype, device='cuda')

keys = []
@@ -165,22 +165,28 @@ def run_single_query_cached_kv_attention(
num_blocks: int,
dtype: torch.dtype,
) -> None:
qkv = torch.empty(
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
qkv = torch.empty(num_tokens,
3,
num_heads,
head_size,
dtype=dtype,
device='cuda')
qkv.uniform_(-1e-3, 1e-3)
query, _, _ = qkv.unbind(dim=1)

x = 16 // torch.tensor([], dtype=dtype).element_size()
key_block_shape = (num_heads, head_size // x, block_size, x)
key_cache = torch.empty(
size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
key_cache = torch.empty(size=(num_blocks, *key_block_shape),
dtype=dtype,
device='cuda')
key_cache.uniform_(-1e-3, 1e-3)
value_block_shape = (num_heads, head_size, block_size)
value_cache = torch.empty(
size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
value_cache = torch.empty(size=(num_blocks, *value_block_shape),
dtype=dtype,
device='cuda')
value_cache.uniform_(-1e-3, 1e-3)

context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')

@@ -194,9 +200,12 @@ def run_single_query_cached_kv_attention(
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')

scale = float(1.0 / (head_size ** 0.5))
output = torch.empty(
num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
scale = float(1.0 / (head_size**0.5))
output = torch.empty(num_tokens,
num_heads,
head_size,
dtype=dtype,
device='cuda')
attention_ops.single_query_cached_kv_attention(
output,
query,
@@ -235,9 +244,13 @@ def run_multi_query_kv_attention(
seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
num_tokens = sum(seq_lens)

scale = float(1.0 / (head_size ** 0.5))
qkv = torch.empty(
num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
scale = float(1.0 / (head_size**0.5))
qkv = torch.empty(num_tokens,
3,
num_heads,
head_size,
dtype=dtype,
device='cuda')
qkv.uniform_(-1e-3, 1e-3)
query, key, value = qkv.unbind(dim=1)
