Support W8A8 inference in vllm #1508

Closed · wants to merge 110 commits

Changes from all commits (110 commits)

e08acaa
add llama quant
Aug 14, 2023
387c804
change weight path
Aug 14, 2023
68cd1e0
fix weight load
Aug 15, 2023
ca088d6
merge gate and up matrix
Aug 16, 2023
6bde51e
use FTLlamaRMSNorm
Aug 17, 2023
931e51c
support bitsandbytes int8
Aug 28, 2023
c0c2a4d
llama support bnb 4bit
Aug 30, 2023
f677586
add int8gemm
Sep 20, 2023
5d79d03
support int8 inference
Sep 20, 2023
24ed816
Reduce unnecessary alpha/beta d2h transfers
sleepcoo Sep 21, 2023
0e84a61
fix weight load
Sep 21, 2023
7baa3ac
fix weight load
Sep 22, 2023
7478e3c
fix ln layer init
Sep 22, 2023
5099364
rms norm fusion
Sep 26, 2023
b051d53
fix w8a8 linear
Sep 26, 2023
33796f4
use same scale across tensors
Sep 26, 2023
a73de71
add ftgemm
Sep 27, 2023
73caa70
fix cublas linear
Sep 27, 2023
3b7b967
clean cuBLAS GEMM code
Sep 27, 2023
6550983
code clean
Sep 27, 2023
27d806b
fuse dequant silu and quant
Sep 28, 2023
1abaf67
fuse dequant and add residual
Sep 28, 2023
7468658
fuse dequant, add residual, rms_norm and quant
Sep 28, 2023
d5d4fcd
fuse dequant and pos_encoding
Sep 28, 2023
9056fd0
setup for fused kernels
Sep 28, 2023
627dac8
fix bugs
Sep 28, 2023
388a215
add tests for fusion kernels
Oct 9, 2023
b08a5e3
fix non-contiguous tensor case
Oct 17, 2023
60484fb
add quant, dequant kernel
Oct 17, 2023
c2b5750
optimize layernorm kernel
Oct 17, 2023
8c60613
add python class DequantAddResidualI8RMSNormQuant, DequantPagedAttent…
Oct 17, 2023
99ad6e5
add tests
Oct 17, 2023
c0a5f3b
add w8a8linear without quant and dequant
Oct 17, 2023
415b0af
adjust code for fusion
Oct 17, 2023
1734727
rm obsolete file
Oct 18, 2023
b6587a3
fix llama
Oct 19, 2023
63df225
remove cutlass dependency
Oct 19, 2023
b44ebb0
add sq quantized linear
Oct 24, 2023
25e63fd
rm unit test for w8a8 linear
Oct 24, 2023
e0d3aa0
adjust i8 llama weight load
Oct 24, 2023
b6cdc76
add fusion.py
Oct 26, 2023
3b89a4d
code clean
Oct 30, 2023
ce271bc
support kv cache quantization
Sep 19, 2023
f8b0b05
fix python code
Sep 19, 2023
b1560db
merge and reformat
Sep 20, 2023
5c672ec
support generating kv quant parameters and evaluating kv quant models
Sep 27, 2023
f8d6b99
modify test functions
Sep 28, 2023
f8427e3
fix test code
Sep 28, 2023
df286fe
fix test attention
Sep 28, 2023
b2d9b8c
modify attention kernel test using pytest
Oct 12, 2023
c5a1a73
fix quant parameter passing
Oct 16, 2023
fbed95c
code clean
Oct 30, 2023
f396ed3
code clean
Oct 30, 2023
c76d864
simplify code
Oct 30, 2023
076f79d
simplify code
Oct 30, 2023
bc2038f
Merge branch 'main' into w8a8
AniZpZ Nov 2, 2023
ad8f950
Merge branch 'main' into kv_quant
AniZpZ Nov 2, 2023
ac24163
code format
Nov 3, 2023
1d256f6
code format
Nov 3, 2023
9db5a63
code format
Nov 3, 2023
2543722
code format
Nov 3, 2023
4226683
code format
Nov 3, 2023
df15d44
fix merge
Nov 15, 2023
872d156
fix reshape_and_cache_quantized
Nov 20, 2023
6d3ddd8
fix load weight bug
Nov 9, 2023
9ad3e5b
update w8a8 kernels
Nov 20, 2023
2de11d9
fix w8a8
Nov 21, 2023
3c25ce3
update kernel tests
Nov 21, 2023
9448d1b
fix merge problems
Nov 22, 2023
443d3e3
fix bugs
Nov 21, 2023
ca2a1f9
fix pos_emb bugs
Nov 22, 2023
d9163e2
fix merge problems
Nov 22, 2023
d08ba77
fix merge problems
Nov 22, 2023
8f9d9ef
fix w8a8
Nov 22, 2023
81b0dc8
code format
Nov 22, 2023
d607725
fix rotary_embedding bug
Nov 22, 2023
8c29013
tmp fix
Nov 22, 2023
8b5278d
tmp fix2
Nov 22, 2023
1edc0da
fix linear
Nov 23, 2023
17c426a
fix int8 linear bugs
Nov 23, 2023
d8a9d4a
update kv-quant kernels
Nov 23, 2023
0b06f96
add kv-quant kernel tests
Nov 23, 2023
734dcc6
support kv-quant
Nov 23, 2023
f2d3dac
Merge branch 'kv_quant' into vllmq
Nov 23, 2023
90a0be2
fix merge problems
Nov 24, 2023
11ea458
fix work bugs
Nov 24, 2023
d5ab474
fix kernel bugs
Dec 1, 2023
e12104f
support per-token quant tensor parallel
Dec 1, 2023
0a64392
fix tp load weight
Dec 1, 2023
bd48d23
fix work bugs
Nov 24, 2023
38bcfce
support per-token quant tensor parallel
Dec 1, 2023
7fba861
fix tp load weight
Dec 1, 2023
aa9dcfd
fix kv-quant args
Dec 5, 2023
10d73f4
fix I8CUGEMM duplicate instance issue
Dec 5, 2023
25a6837
fix I8CUGEMM duplicate instance issue
Dec 5, 2023
6c1ee44
Merge tag 'v0.2.3' into vllmq_0.2.3
Dec 7, 2023
28e3441
optimize kernels and dequant scale
Dec 12, 2023
858673f
fix dtype bug
Dec 13, 2023
5bbe52d
optimize kernels and dequant scale
Dec 12, 2023
7e6f203
fix kv-quant bugs
Dec 13, 2023
d6eeee8
fix attention params
Dec 18, 2023
6d09d83
merge utils.cuh to quant_utils.cuh
Jan 8, 2024
63fd0c8
fix bugs
Jan 8, 2024
7e8f65b
Merge tag 'v0.2.7' into vllmq_0.2.7
Jan 9, 2024
baac1a9
refactor smoothquant code
Jan 16, 2024
dcf8e8f
Merge remote-tracking branch 'origin/vllmq_merge' into w8a8_v0.2.7
Jan 16, 2024
db074fa
format code
Jan 16, 2024
b07173d
add .buildkite
Jan 16, 2024
b6606ab
fix quant config
Jan 31, 2024
b997633
fix tp bug
Jan 31, 2024
24 changes: 24 additions & 0 deletions .buildkite/run-benchmarks.sh
@@ -0,0 +1,24 @@
# This script is run by buildkite to run the benchmarks and upload the results to buildkite

set -ex

# cd into parent directory of this file
cd "$(dirname "${BASH_SOURCE[0]}")/.."

# run benchmarks and upload the result to buildkite
python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt

python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt

# write the results into a markdown file
echo "### Latency Benchmarks" >> benchmark_results.md
sed -n '1p' benchmark_latency.txt >> benchmark_results.md
echo "" >> benchmark_results.md
sed -n '$p' benchmark_latency.txt >> benchmark_results.md
echo "### Throughput Benchmarks" >> benchmark_results.md
sed -n '1p' benchmark_throughput.txt >> benchmark_results.md
echo "" >> benchmark_results.md
sed -n '$p' benchmark_throughput.txt >> benchmark_results.md

# upload the results to buildkite
/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md
41 changes: 41 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -0,0 +1,41 @@
# In this file, you can add more tests to run either by adding a new step or
# adding a new command to an existing step. See different options here for examples.
# This script will be fed into the Jinja template in `test-template.j2` to generate
# the final pipeline yaml file.

steps:
- label: Regression Test
command: pytest -v -s test_regression.py
working_dir: "/vllm-workspace/tests" # optional

- label: AsyncEngine Test
command: pytest -v -s async_engine

- label: Distributed Test
command: pytest -v -s test_comm_ops.py
working_dir: "/vllm-workspace/tests/distributed"
num_gpus: 2 # only support 1 or 2 for now.

- label: Engine Test
command: pytest -v -s engine

- label: Kernels Test
command: pytest -v -s kernels
soft_fail: true

- label: Models Test
commands:
- pytest -v -s models --forked
soft_fail: true

- label: Samplers Test
command: pytest -v -s samplers --forked

- label: Worker Test
command: pytest -v -s worker

- label: Benchmarks
working_dir: "/vllm-workspace/.buildkite"
commands:
- pip install aiohttp
- bash run-benchmarks.sh
50 changes: 50 additions & 0 deletions .buildkite/test-template.j2
@@ -0,0 +1,50 @@
{% set docker_image = "us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT" %}
{% set default_num_gpu = 1 %}
{% set default_working_dir = "/vllm-workspace/tests" %}

steps:
- label: ":docker: build image"
commands:
- "docker build --tag {{ docker_image }} --target test --progress plain ."
- "docker push {{ docker_image }}"
env:
DOCKER_BUILDKIT: "1"
- wait

{% for step in steps %}
- label: "{{ step.label }}"
agents:
queue: kubernetes
soft_fail: {{ step.soft_fail or false }}
retry:
automatic:
- exit_status: -1 # Agent was lost
limit: 5
plugins:
- kubernetes:
podSpec:
volumes:
- name: dshm
emptyDir:
medium: Memory
containers:
- image: "{{ docker_image }}"
command: ["bash"]
args:
- "-c"
- "'cd {{ (step.working_dir or default_working_dir) | safe }} && {{ step.command or (step.commands | join(' && ')) | safe }}'"
resources:
requests:
nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
limits:
nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
env:
- name: HF_TOKEN
valueFrom:
secretKeyRef:
name: hf-token-secret
key: token
volumeMounts:
- mountPath: /dev/shm
name: dshm
{% endfor %}
2 changes: 1 addition & 1 deletion csrc/activation_kernels.cu
@@ -115,4 +115,4 @@ void gelu_fast(
torch::Tensor& input) // [..., d]
{
LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
}
}
1 change: 1 addition & 0 deletions csrc/attention/attention_dtypes.h
@@ -4,3 +4,4 @@
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh"
#include "dtype_int8.cuh"
2 changes: 1 addition & 1 deletion csrc/attention/attention_kernels.cu
@@ -879,4 +879,4 @@ void paged_attention_v2(
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
#undef DIVIDE_ROUND_UP
8 changes: 8 additions & 0 deletions csrc/attention/dtype_float32.cuh
@@ -86,6 +86,14 @@ inline __device__ float4 add(float4 a, float4 b) {
return c;
}

// Required for compilation: the float4 overload above does not cover the Float4_ struct.
inline __device__ Float4_ add(Float4_ a, Float4_ b) {
Float4_ c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}

// Vector multiplication.
template<>
inline __device__ float mul<float, float>(float a, float b) {
49 changes: 49 additions & 0 deletions csrc/attention/dtype_int8.cuh
@@ -0,0 +1,49 @@
#pragma once

#include <stdint.h>
#include "attention_generic.cuh"
#include "dtype_float32.cuh"

namespace vllm {
// define int8 vector types for quantization of kv cache

template<>
struct Vec<int8_t, 1> {
using Type = int8_t;
};

template<>
struct Vec<int8_t, 2> {
using Type = int16_t;
};

template<>
struct Vec<int8_t, 4> {
using Type = int32_t;
};

template<>
struct Vec<int8_t, 8> {
using Type = int64_t;
};

template<>
struct FloatVec<int8_t> {
using Type = float;
};

template<>
struct FloatVec<int16_t> {
using Type = float2;
};

template<>
struct FloatVec<int32_t> {
using Type = Float4_;
};

template<>
struct FloatVec<int64_t> {
using Type = Float8_;
};
}
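
The new dtype_int8.cuh maps packed int8 vectors onto ordinary integer registers (for example, Vec<int8_t, 4>::Type is int32_t) and pairs each packed type with a float accumulator type via FloatVec, mirroring the existing float16/bfloat16 headers. Below is a minimal sketch of how such a packed value would be dequantized into a Float4_; the helper name dequant_int8x4 is hypothetical and not part of this PR, and Float4_ is assumed to be the two-float2 struct from dtype_float32.cuh.

```cuda
// Illustrative sketch only: dequantize four int8 lanes packed in an int32_t
// (Vec<int8_t, 4>::Type) into the matching FloatVec type, Float4_.
#include <stdint.h>
#include "dtype_float32.cuh"  // Float4_ (two float2 halves)

namespace vllm {

inline __device__ Float4_ dequant_int8x4(int32_t packed, float scale) {
  // Reinterpret the 32-bit register as four signed 8-bit lanes.
  const int8_t* lanes = reinterpret_cast<const int8_t*>(&packed);
  Float4_ out;
  out.x = make_float2(lanes[0] * scale, lanes[1] * scale);
  out.y = make_float2(lanes[2] * scale, lanes[3] * scale);
  return out;
}

}  // namespace vllm
```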
2 changes: 1 addition & 1 deletion csrc/cache_kernels.cu
@@ -390,4 +390,4 @@ void gather_cached_kv(
block_size,
x);
});
}
}
11 changes: 10 additions & 1 deletion csrc/dispatch_utils.h
@@ -9,8 +9,17 @@
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
// AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)

#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
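
The added VLLM_DISPATCH_QUANT_TYPES macro extends the floating-point dispatch cases with at::ScalarType::Char, so a single templated kernel can also be instantiated for int8 tensors. The following is a hedged usage sketch, not code from the PR: example_copy_kernel and example_copy are hypothetical names, and the launch configuration is arbitrary.

```cuda
// Usage sketch for VLLM_DISPATCH_QUANT_TYPES: one templated kernel covers
// float, half, bfloat16 and int8 (at::ScalarType::Char) inputs.
#include <torch/extension.h>
#include "dispatch_utils.h"  // VLLM_DISPATCH_QUANT_TYPES (defined above)

template <typename scalar_t>
__global__ void example_copy_kernel(scalar_t* __restrict__ out,
                                    const scalar_t* __restrict__ in,
                                    int num_elems) {
  const int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < num_elems) {
    out[idx] = in[idx];
  }
}

void example_copy(torch::Tensor& out, torch::Tensor& input) {
  const int num_elems = static_cast<int>(input.numel());
  const dim3 grid((num_elems + 255) / 256);
  const dim3 block(256);
  VLLM_DISPATCH_QUANT_TYPES(
      input.scalar_type(), "example_copy_kernel", [&] {
        example_copy_kernel<scalar_t><<<grid, block>>>(
            out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), num_elems);
      });
}
```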
2 changes: 1 addition & 1 deletion csrc/layernorm_kernels.cu
@@ -117,4 +117,4 @@ void fused_add_rms_norm(
num_tokens,
hidden_size);
});
}
}
92 changes: 92 additions & 0 deletions csrc/ops.h
@@ -89,3 +89,95 @@ torch::Tensor gptq_gemm(
void gptq_shuffle(
torch::Tensor q_weight,
torch::Tensor q_perm);

// These are kernels used by smoothquant
void rms_norm_quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& weight,
float epsilon);

void dequant_add_residual_rms_norm_quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& residual,
torch::Tensor& gamma,
float scale,
float epsilon);

void dequant_add_residual_rms_norm_quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& residual,
torch::Tensor& gamma,
torch::Tensor& scale,
float epsilon,
float weight_dequant_scale);

void add_residual_rms_norm_quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& residual,
torch::Tensor& weight,
float epsilon);

void dequant_rotary_embedding(
torch::Tensor& positions,
torch::Tensor& query,
torch::Tensor& key,
int head_size,
torch::Tensor& cos_sin_cache,
bool is_neox,
torch::Tensor& query_out,
torch::Tensor& key_out,
float query_scale,
float key_scale);

void dequant_silu_and_mul_quant(
torch::Tensor& out,
torch::Tensor& input,
float gate_scale,
float up_scale,
float out_scale);

void dequant_silu_and_mul_quant(
torch::Tensor& out,
torch::Tensor& input,
float gate_scale,
float up_scale,
torch::Tensor& out_scale,
torch::Tensor& tmp);

void dequant_add_residual(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& residual,
float scale);

void dequant_add_residual(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& residual,
torch::Tensor& scale,
float weight_dequant_scale);

void dequant(
torch::Tensor& out,
torch::Tensor& input,
float scale);

void dequant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale,
float weight_dequant_scale);

void quant(
torch::Tensor& out,
torch::Tensor& input,
float scale);

void quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale);
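
The ops.h additions declare the fused SmoothQuant kernels; the largest fusion, dequant_add_residual_rms_norm_quant, folds dequantization, the residual add, RMSNorm and re-quantization into a single pass, so the intermediate floating-point activations never round-trip through global memory between two quantized GEMMs. The sketch below spells out the per-token math such a fusion is expected to compute, written as plain reference code rather than the PR's CUDA kernel; the interpretation of input as the int32 accumulator of the preceding int8 GEMM, the separate dequant/quant scales, and the divide-by-scale quantization convention are assumptions for illustration.

```cuda
// Reference semantics sketch (host code, not the fused CUDA kernel):
// dequant + add residual + RMSNorm + quant for one token of hidden_size elements.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<int8_t> dequant_add_residual_rms_norm_quant_ref(
    const std::vector<int32_t>& input,   // assumed int32 output of an int8 GEMM
    const std::vector<float>& residual,  // floating-point residual stream
    const std::vector<float>& gamma,     // RMSNorm weight
    float dequant_scale, float quant_scale, float epsilon) {
  const size_t n = input.size();
  std::vector<float> x(n);
  float sum_sq = 0.0f;
  for (size_t i = 0; i < n; ++i) {
    x[i] = input[i] * dequant_scale + residual[i];  // dequantize, add residual
    sum_sq += x[i] * x[i];
  }
  const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(n) + epsilon);
  std::vector<int8_t> out(n);
  for (size_t i = 0; i < n; ++i) {
    const float y = x[i] * inv_rms * gamma[i];    // RMSNorm
    const float q = std::round(y / quant_scale);  // re-quantize to int8
    out[i] = static_cast<int8_t>(std::min(127.0f, std::max(-128.0f, q)));
  }
  return out;
}
```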
2 changes: 1 addition & 1 deletion csrc/pos_encoding_kernels.cu
@@ -127,4 +127,4 @@ void rotary_embedding(
head_size);
}
});
}
}