Support multi-gpu inference
research4pan committed Jan 11, 2024
1 parent f49123d commit adf825a
Showing 8 changed files with 97 additions and 138 deletions.
4 changes: 2 additions & 2 deletions configs/accelerator_multigpu_config.yaml
@@ -8,11 +8,11 @@ machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
- num_processes: 2
+ num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
- main_process_port: 11000
+ main_process_port: 11002
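As a side note, num_processes in this Accelerate config is the number of worker processes (normally one per GPU), and it can be overridden at launch time instead of editing the file. A minimal sketch, assuming a hypothetical 2-GPU run:

accelerate launch --config_file configs/accelerator_multigpu_config.yaml \
    --num_processes 2 \
    examples/chatbot.py \
    --deepspeed configs/ds_config_chatbot.json \
    --use_accelerator True \
    --model_name_or_path gpt2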
11 changes: 3 additions & 8 deletions configs/ds_config_zero3.json
@@ -15,21 +15,16 @@
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
"device": "cpu"
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_max_live_parameters": 2e10,
"stage3_max_reuse_distance": 2e10,
"stage3_gather_16bit_weights_on_model_save": true
},

3 changes: 2 additions & 1 deletion examples/chatbot.py
@@ -61,6 +61,7 @@ def main():
tune_strategy='none',
ds_config=ds_config,
device=pipeline_args.device,
+ use_accelerator=True,
)

# We don't need input data, we will read interactively from stdin
@@ -120,7 +121,7 @@ def main():
print("Bot: ", end="")
print_index = 0

- token_per_step = 4
+ token_per_step = 100

for response, flag_break in inferencer.stream_inference(
context=context,
15 changes: 10 additions & 5 deletions scripts/run_chatbot.sh
@@ -9,8 +9,13 @@ if [ $# -ge 2 ]; then
lora_args="--lora_model_path $2"
fi

- CUDA_VISIBLE_DEVICES=0 \
- deepspeed examples/chatbot.py \
- --deepspeed configs/ds_config_chatbot.json \
- --model_name_or_path ${model} \
- ${lora_args}
+ # --temperature 0.7 \
+ accelerate launch --config_file configs/accelerator_multigpu_config.yaml \
+ examples/chatbot.py \
+ --deepspeed configs/ds_config_chatbot.json \
+ --model_name_or_path ${model} \
+ --use_accelerator True \
+ --max_new_tokens 256 \
+ --temperature 1.0 \
+ --end_string "#" \
+ ${lora_args}
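For context, the updated script is invoked the same way as before; a minimal usage sketch, assuming the script's first argument selects the model and the optional second argument points at a LoRA adapter (both paths below are placeholders):

bash scripts/run_chatbot.sh gpt2
bash scripts/run_chatbot.sh path/to/base_model path/to/lora_adapter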
4 changes: 4 additions & 0 deletions src/lmflow/args.py
@@ -798,6 +798,10 @@ class InferencerArguments:
"help": "whether turn on true random sampling during inference."
},
)
+ use_accelerator: bool = field(
+ default=False, metadata={"help": "Whether to use Huggingface Accelerator instead of Deepspeed"},
+ )


@dataclass
class RaftAlignerArguments(TrainingArguments):
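Since the new flag is an ordinary dataclass field, it should surface as a --use_accelerator command-line option through the usual HfArgumentParser flow. A minimal sketch under that assumption (the parser wiring below is illustrative, not the exact LMFlow entry point):

from transformers import HfArgumentParser
from lmflow.args import InferencerArguments

# Parse only the inferencer arguments; real entry points parse several
# dataclasses (model, data, inferencer) at once.
parser = HfArgumentParser(InferencerArguments)
(inferencer_args,) = parser.parse_args_into_dataclasses(args=["--use_accelerator", "True"])
print(inferencer_args.use_accelerator)  # True -> run inference through Accelerate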
1 change: 0 additions & 1 deletion src/lmflow/models/hf_decoder_model.py
@@ -290,7 +290,6 @@ def __init__(
revision=model_args.model_revision,
use_auth_token=True if model_args.use_auth_token else None,
torch_dtype=torch_dtype,
- device_map=device_map,
trust_remote_code = model_args.trust_remote_code,
)
#for deepspeed zero3, we don't need to specify device_map
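The reason for dropping device_map here is the one the surviving comment gives: under DeepSpeed ZeRO-3, parameter partitioning and placement are handled by DeepSpeed itself, so letting from_pretrained dispatch the model with device_map is unnecessary and can conflict with ZeRO-3's sharding. A minimal sketch of the loading call after this change (the AutoModelForCausalLM class and the placeholder names below are assumptions, not LMFlow's exact code):

import torch
from transformers import AutoModelForCausalLM

# Placeholders standing in for LMFlow's parsed model arguments and dtype selection.
model_name_or_path = "gpt2"
torch_dtype = torch.bfloat16

# Under DeepSpeed ZeRO-3, omit device_map and let DeepSpeed place/shard parameters.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch_dtype,
    trust_remote_code=False,
)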
37 changes: 29 additions & 8 deletions src/lmflow/pipeline/inferencer.py
@@ -16,6 +16,7 @@
from typing import Dict, List
from concurrent.futures import ThreadPoolExecutor

+ from accelerate import Accelerator
from transformers import AutoConfig
import torch.distributed as dist
import torch.nn.functional as F
@@ -81,6 +82,10 @@ def __init__(self, model_args, data_args, inferencer_args):
print("Error in setting hidden size, use the default size 1024")
self.model_hidden_size = 1024 # gpt2 seems do not have hidden_size in config

+ if inferencer_args.use_accelerator:
+ self.accelerator = Accelerator()
+ self.accelerator.wait_for_everyone()


def create_dataloader(self, dataset: Dataset):
r"""Batchlize dataset and format it to dataloader.
@@ -161,7 +166,7 @@ def inference(
input = current_batch['input']
input['text'] = prompt_structure.format(input=input['text'])

- if 'images' in input and isinstance(input['images'], list):
+ if False and 'images' in input and isinstance(input['images'], list):
input['images'] = np.array(input['images'])
if remove_image_flag:
# remove the image flag <ImageHere> in tokenization;
@@ -219,17 +224,33 @@ def inference(
raise NotImplementedError(
f"device \"{self.inferencer_args.device}\" is not supported"
)

+ if self.inferencer_args.use_accelerator:
+ inputs = inputs.to(self.accelerator.device)


if remove_image_flag:
inputs["image_token_indexes"] = image_token_indexes
inputs["one_sample_multiple_images"] = True

- outputs = model.inference(
- inputs,
- max_new_tokens=max_new_tokens,
- temperature=self.inferencer_args.temperature,
- repetition_penalty=self.inferencer_args.repetition_penalty,
- do_sample=self.inferencer_args.do_sample,
- )
+ if self.inferencer_args.use_accelerator:
+ with self.accelerator.autocast():
+ outputs = model.inference(
+ inputs,
+ max_new_tokens=max_new_tokens,
+ temperature=self.inferencer_args.temperature,
+ repetition_penalty=self.inferencer_args.repetition_penalty,
+ do_sample=self.inferencer_args.do_sample,
+ use_accelerator=True,
+ )
+ else:
+ outputs = model.inference(
+ inputs,
+ max_new_tokens=max_new_tokens,
+ temperature=self.inferencer_args.temperature,
+ repetition_penalty=self.inferencer_args.repetition_penalty,
+ do_sample=self.inferencer_args.do_sample,
+ )

# only return the generation, truncating the input
if self.model_args.arch_type != "vision_encoder_decoder":
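Taken together, the new path follows the standard Accelerate inference pattern: create an Accelerator, synchronize the processes, move inputs to accelerator.device, and generate under autocast. A stripped-down sketch of that pattern outside LMFlow (model name is a placeholder, and generate stands in for LMFlow's model.inference wrapper; each process started by `accelerate launch` runs its own copy):

import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

accelerator = Accelerator()          # one process per GPU under `accelerate launch`
accelerator.wait_for_everyone()      # make sure every rank is ready before generating

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16)
model.to(accelerator.device)
model.eval()

inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(accelerator.device)
with torch.no_grad(), accelerator.autocast():
    outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))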
160 changes: 47 additions & 113 deletions src/lmflow/utils/flash_attention/llama_flash_attention.py
@@ -11,9 +11,9 @@

#try to import flash_attn 2.x.x, if not, import flash_attn 1.x.x
try:
- from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+ from flash_attn.flash_attn_interface import flash_attn_func
except:
- from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func as flash_attn_func

from flash_attn.bert_padding import unpad_input, pad_input

@@ -22,136 +22,70 @@ def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel
attention_mask: [bsz, q_len]
"""
bsz, q_len, _ = hidden_states.size()

- query_states = (
- self.q_proj(hidden_states)
- .view(bsz, q_len, self.num_heads, self.head_dim)
- .transpose(1, 2)
- )
- key_states = (
- self.k_proj(hidden_states)
- .view(bsz, q_len, self.num_heads, self.head_dim)
- .transpose(1, 2)
- )
- value_states = (
- self.v_proj(hidden_states)
- .view(bsz, q_len, self.num_heads, self.head_dim)
- .transpose(1, 2)
- )
- # [bsz, q_len, nh, hd]
- # [bsz, nh, q_len, hd]
+ if self.config.pretraining_tp > 1:
+ raise ValueError("pretraining_tp > 1 is not supported for flash attention")
+ else:
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)

+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
- # assert past_key_value is None, "past_key_value is not supported"

if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]

cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin, position_ids
- )
- # [bsz, nh, t, hd]
- # assert not output_attentions, "output_attentions is not supported"
- # assert not use_cache, "use_cache is not supported"
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

past_key_value = (key_states, value_states) if use_cache else None

- # Flash attention codes from
- # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py

- # transform the data into the format required by flash attention
- # import pdb; pdb.set_trace()

- if query_states.shape[-2] == 1 or query_states.shape[-2] != key_states.shape[-2]:
- # decode token-by-token, do not use flash attention
- # in incremental state, do not use flash attention
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

- if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
- f" {attn_weights.size()}"
- )

- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
- )
- attn_weights = attn_weights + attention_mask
- attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

- # upcast attention to fp32
- attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
- attn_output = torch.matmul(attn_weights, value_states)

- if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )

- attn_output = attn_output.transpose(1, 2)
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

- attn_output = self.o_proj(attn_output)

- if not output_attentions:
- attn_weights = None

- return attn_output, attn_weights, past_key_value

# transform the data into the format required by flash attention
- qkv = torch.stack(
- [query_states, key_states, value_states], dim=2
- ) # [bsz, nh, 3, q_len, hd]
- qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
- # We have disabled _prepare_decoder_attention_mask in LlamaModel
- # the attention_mask should be the same as the key_padding_mask
- key_padding_mask = attention_mask

- if key_padding_mask is None:
- qkv = rearrange(qkv, "b s ... -> (b s) ...")
- max_s = q_len
- cu_q_lens = torch.arange(
- 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
- )
- output = flash_attn_varlen_qkvpacked_func(
- qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ query_states, key_states, value_states = [
+ rearrange(x, "b h s d -> b s h d") for x in [query_states, key_states, value_states]
+ ]

+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ # Handle the case where the model is quantized
+ if hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype

+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)

+ # below output will have shape (batch_size, seqlen, nheads, headdim)
+ attn_output = flash_attn_func(query_states, key_states, value_states, causal=True)

+ if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
+ f" {attn_output.size()}"
)
- output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
- else:
- nheads = qkv.shape[-2]
- x = rearrange(qkv, "b s three h d -> b s (three h d)")
- x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
- x_unpad = rearrange(
- x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
- )
- output_unpad = flash_attn_varlen_qkvpacked_func(
- x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
- )
- output = rearrange(
- pad_input(
- rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
- ),
- "b s (h d) -> b s h d",
- h=nheads,
- )
- return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, past_key_value

+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+ if output_attentions:
+ raise NotImplementedError("`output_attentions` is not supported when `use_flash_attn` is True")
+ attn_weights = None

+ return attn_output, attn_weights, past_key_value


# Disable the transformation of the attention mask in LlamaModel as the flash attention
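For readers unfamiliar with the flash-attn 2.x API adopted above: flash_attn_func takes unpacked q, k, v tensors of shape (batch, seqlen, nheads, headdim) in fp16/bf16 and returns the attention output in the same layout, which is why the patched forward rearranges from (b, h, s, d) and casts away from float32 before the call. A tiny standalone sketch (shapes and dtype are illustrative; requires a CUDA device with flash-attn 2.x installed):

import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim), half precision, on GPU -- as flash-attn requires
q = torch.randn(2, 128, 32, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(2, 128, 32, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2, 128, 32, 64, dtype=torch.bfloat16, device="cuda")

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)  # output keeps the (b, s, h, d) layout
print(out.shape)  # torch.Size([2, 128, 32, 64])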
