[Feature][kernel] tensor parallelism with bitsandbytes quantization #8434

Merged · 5 commits · Sep 17, 2024
26 changes: 21 additions & 5 deletions tests/quantization/test_bitsandbytes.py
@@ -64,6 +64,24 @@ def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
model_name)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason='Test requires at least 2 GPUs.')
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                    reason='bitsandbytes is not supported on this GPU type.')
@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
@fork_new_process_for_each_test
def test_load_tp_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                                model_name, description) -> None:

    hf_model_kwargs = {"load_in_4bit": True}
    validate_generated_texts(hf_runner,
                             vllm_runner,
                             example_prompts[:1],
                             model_name,
                             hf_model_kwargs,
                             vllm_tp_size=2)


def log_generated_texts(prompts, outputs, runner_name):
logged_texts = []
for i, (_, generated_text) in enumerate(outputs):
@@ -80,22 +98,21 @@ def validate_generated_texts(hf_runner,
vllm_runner,
prompts,
model_name,
hf_model_kwargs=None):
hf_model_kwargs=None,
vllm_tp_size=1):

# NOTE: run vLLM first, as it requires a clean process
# when using distributed inference

#Run with vLLM runner
with vllm_runner(model_name,
quantization='bitsandbytes',
load_format='bitsandbytes',
tensor_parallel_size=vllm_tp_size,
enforce_eager=True,
gpu_memory_utilization=0.8) as llm:
vllm_outputs = llm.generate_greedy(prompts, 8)
vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")

# Clean up the GPU memory for the next test
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()

@@ -108,7 +125,6 @@ def validate_generated_texts(hf_runner,
hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")

# Clean up the GPU memory for the next test
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()

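For context, here is a minimal sketch of what the new TP test exercises from the user's side. It assumes a machine with at least two GPUs; the model name is only an example (the test actually parametrizes over `models_4bit_to_test`, which is not shown in this diff):

```python
from vllm import LLM, SamplingParams

# Quantize an unquantized checkpoint to 4-bit bitsandbytes at load time and
# shard it across two GPUs; enforce_eager is still required for bitsandbytes.
llm = LLM(
    model="huggyllama/llama-7b",   # example checkpoint, not taken from the diff
    quantization="bitsandbytes",
    load_format="bitsandbytes",
    tensor_parallel_size=2,        # newly supported by this PR
    enforce_eager=True,
    gpu_memory_utilization=0.8,
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)
```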
6 changes: 0 additions & 6 deletions vllm/config.py
@@ -418,12 +418,6 @@ def verify_with_parallel_config(
"Pipeline parallelism is only supported for the following "
f" architectures: {_PP_SUPPORTED_MODELS}.")

if self.quantization == "bitsandbytes" and (
parallel_config.tensor_parallel_size > 1
or parallel_config.pipeline_parallel_size > 1):
raise ValueError(
"BitAndBytes quantization with TP or PP is not supported yet.")

# Remove the constraint after the bitsandbytes issue is fixed:
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308
if self.quantization == "bitsandbytes" and self.enforce_eager is False:
21 changes: 16 additions & 5 deletions vllm/model_executor/layers/linear.py
@@ -530,8 +530,11 @@ def weight_loader(self,
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
@@ -899,8 +902,13 @@ def weight_loader(self,
else:
shard_id = tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)

# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if not use_bitsandbytes_4bit:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)

# Special case for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
@@ -1000,6 +1008,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)

# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
@@ -1015,7 +1024,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

param_data = param.data
if input_dim is not None:
# bitsandbytes loads the weights of the specific portion
# no need to narrow here
if input_dim is not None and not use_bitsandbytes_4bit:
shard_size = param_data.shape[input_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
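To make the skipped-narrow branches above concrete, here is an illustrative, stand-alone sketch. The helper name `select_shard` is invented for this example; vLLM's weight loaders inline this logic:

```python
import torch

def select_shard(loaded_weight: torch.Tensor, output_dim: int, tp_rank: int,
                 shard_size: int, use_bitsandbytes_4bit: bool) -> torch.Tensor:
    # For ordinary checkpoints each rank slices its shard out of the full
    # tensor; the bitsandbytes loader already hands the layer a per-rank
    # shard, so narrowing again would slice the wrong region.
    if use_bitsandbytes_4bit:
        return loaded_weight
    start_idx = tp_rank * shard_size
    return loaded_weight.narrow(output_dim, start_idx, shard_size)

full = torch.randn(8, 4)   # toy column-parallel weight, output dim = 0
shard = select_shard(full, output_dim=0, tp_rank=1, shard_size=4,
                     use_bitsandbytes_4bit=False)
assert shard.shape == (4, 4)
```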
44 changes: 43 additions & 1 deletion vllm/model_executor/model_loader/loader.py
@@ -22,6 +22,8 @@
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, MultiModalConfig,
ParallelConfig, SchedulerConfig)
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
@@ -689,6 +691,8 @@ def save_model(
class BitsAndBytesModelLoader(BaseModelLoader):
"""Model loader to load model weights with BitAndBytes quantization."""

# TODO: these module names are for Llama only,
# change so that it works with other models as well
default_target_modules = [
"gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj",
"o_proj"
@@ -911,13 +915,44 @@ def _parse_quant_state(param_name: str,
def _unquantized_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
from bitsandbytes.functional import quantize_4bit
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()

for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if any(target_module in weight_name
for target_module in self.target_modules):
weight_name = weight_name.replace(".weight", ".qweight")

# weight partitions of different modules occur at
# different dimensions
# TODO: these module names are for Llama only,
# change so that it works with other models as well
if 'down_proj' in weight_name or 'o_proj' in weight_name:
total_size = weight_tensor.size(-1)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
weight_sub_tensor = weight_tensor[...,
start_index:end_index]

else:
total_size = weight_tensor.size(0)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
weight_sub_tensor = weight_tensor[start_index:end_index,
...]

# bitsandbytes requires the data to be on the GPU
loaded_weight = weight_tensor.cuda().data
if weight_sub_tensor.is_cuda:
loaded_weight = weight_sub_tensor
else:
loaded_weight = weight_sub_tensor.cuda()

# remove the following after the issue is fixed:
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
if loaded_weight.is_contiguous() is False:
loaded_weight = loaded_weight.contiguous()

with set_default_torch_dtype(torch.float32):
processed_weight, quant_state = quantize_4bit(
loaded_weight,
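The per-rank slicing in `_unquantized_generator` can be summarized with a small stand-alone sketch. The helper name `shard_for_rank` is hypothetical; the branch condition mirrors the Llama-specific module names noted in the TODO above:

```python
import torch

def shard_for_rank(weight_name: str, weight_tensor: torch.Tensor,
                   tp_rank: int, tp_size: int) -> torch.Tensor:
    """Slice the full weight so each rank quantizes only its own shard."""
    if 'down_proj' in weight_name or 'o_proj' in weight_name:
        # Row-parallel layers are partitioned along the input (last) dim.
        total = weight_tensor.size(-1)
        start = total // tp_size * tp_rank
        end = total // tp_size * (tp_rank + 1)
        return weight_tensor[..., start:end]
    # Column-parallel layers are partitioned along the output (first) dim.
    total = weight_tensor.size(0)
    start = total // tp_size * tp_rank
    end = total // tp_size * (tp_rank + 1)
    return weight_tensor[start:end, ...]

w = torch.randn(8, 6)  # toy weight; shapes are illustrative only
assert shard_for_rank("model.layers.0.mlp.down_proj.weight", w,
                      tp_rank=1, tp_size=2).shape == (8, 3)
assert shard_for_rank("model.layers.0.mlp.gate_proj.weight", w,
                      tp_rank=0, tp_size=2).shape == (4, 6)
```

Each shard is then moved to the GPU and passed to `bitsandbytes.functional.quantize_4bit`, so every rank ends up holding only a quantized slice of the full weight.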
@@ -958,6 +993,13 @@ def _load_weights(self, model_config: ModelConfig,
f"BitsAndBytes loader does not support {quant_method} "
"quantization")

# The quant_states in pre_quantized models cannot work with a split
# weight tensor. So TP does not work with pre_quantized bnb models.
if pre_quant and get_tensor_model_parallel_world_size() > 1:
raise ValueError(
"Prequant BitsAndBytes models with TP is not supported."
"Please try with PP.")

load_8bit = False
if pre_quant:
load_8bit = quant_config.get('load_in_8bit', False)
Expand Down