Add support for MPT #334
Merged
Changes from all commits (29 commits)
218bbb5  Add BLOOM without ALiBi (WoosukKwon)
a65960b  Minor (WoosukKwon)
dc45776  Add BLOOM to supported models (WoosukKwon)
b2f1167  Inheritance -> Composition (WoosukKwon)
cca8695  Add ALiBi bias to attention kernel (WoosukKwon)
b691559  Add PagedAttentionWithALiBi (WoosukKwon)
8aac4ed  Fix BLOOM (WoosukKwon)
18d54e6  [Minor] Fix comment (WoosukKwon)
093bfcd  [Minor] single quote -> double quote (WoosukKwon)
047a0be  [Minor] Add more comments (WoosukKwon)
2012b63  Add head size 112 (WoosukKwon)
6d22312  Add MPTConfig (WoosukKwon)
9918d82  Add get_config (WoosukKwon)
7205ebd  AutoConfig -> get_config (WoosukKwon)
35831c4  Implement MPT (WoosukKwon)
3a30e17  Register MPT (WoosukKwon)
cff169c  Add MPT to supported models (WoosukKwon)
504cbc0  Merge branch 'main' into mpt (WoosukKwon)
9ca9dc6  yapf (WoosukKwon)
d3704e0  Remove init_config (WoosukKwon)
cbb47cb  single quote -> double quote (WoosukKwon)
1488f90  Minor (WoosukKwon)
8046433  Minor (WoosukKwon)
941ca40  Minor (WoosukKwon)
c29f679  Minor (WoosukKwon)
1573972  yapf (WoosukKwon)
6a7de44  Merge branch 'main' into mpt (WoosukKwon)
42adfa5  yapf (WoosukKwon)
e745990  Fix (WoosukKwon)
@@ -0,0 +1,279 @@
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
from vllm.transformers_utils.configs.mpt import MPTConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(
    total_num_heads: int,
    alibi_bias_max: int,
) -> torch.Tensor:
    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
    m = m.mul(alibi_bias_max / next_power_of_2)
    slopes = 1.0 / torch.pow(2, m)
    if next_power_of_2 != total_num_heads:
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
    return slopes


class MPTAttention(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        self.d_model = config.d_model
        self.total_num_heads = config.n_heads
        self.clip_qkv = config.attn_config["clip_qkv"]
        self.qk_ln = config.attn_config["qk_ln"]
        self.alibi_bias_max = config.attn_config["alibi_bias_max"]
        assert not config.attn_config["prefix_lm"]
        assert config.attn_config["alibi"]

        self.qkv_proj = ColumnParallelLinear(
            self.d_model,
            3 * self.d_model,
            bias=not config.no_bias,
            gather_output=False,
            perform_initialization=False,
        )
        if self.qk_ln:
            self.q_ln = nn.LayerNorm(self.d_model)
            self.k_ln = nn.LayerNorm(self.d_model)
        self.out_proj = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=not config.no_bias,
            input_is_parallel=True,
            perform_initialization=False,
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        # Create the alibi slopes and slice them.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads,
                                         self.alibi_bias_max)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        self.head_dim = self.d_model // self.total_num_heads
        scaling = self.head_dim**-0.5
        self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
                                            scaling, alibi_slopes)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        del position_ids  # unused.
        qkv, _ = self.qkv_proj(hidden_states)
        if self.clip_qkv is not None:
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.qk_ln:
            q = self.q_ln(q)
            k = self.k_ln(k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                cache_event)
        output, _ = self.out_proj(attn_output)
        return output


class MPTMLP(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        hidden_size = config.d_model
        expansion_ratio = config.expansion_ratio
        intermediate_size = expansion_ratio * hidden_size
        self.up_proj = ColumnParallelLinear(hidden_size,
                                            intermediate_size,
                                            bias=not config.no_bias,
                                            gather_output=False,
                                            perform_initialization=False)
        self.act = get_act_fn("gelu")
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=not config.no_bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, _ = self.up_proj(x)
        x = self.act(x)
        x, _ = self.down_proj(x)
        return x


class MPTBlock(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        hidden_size = config.d_model
        self.norm_1 = nn.LayerNorm(hidden_size)
        self.attn = MPTAttention(config)
        self.norm_2 = nn.LayerNorm(hidden_size)
        self.ffn = MPTMLP(config)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        x = self.norm_1(hidden_states)
        x = self.attn(
            position_ids=position_ids,
            hidden_states=x,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        hidden_states = hidden_states + x
        x = self.norm_2(hidden_states)
        x = self.ffn(x)
        hidden_states = hidden_states + x
        return hidden_states


class MPTModel(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        assert config.embedding_fraction == 1.0
        assert config.norm_type == "low_precision_layernorm"

        self.wte = VocabParallelEmbedding(config.vocab_size,
                                          config.d_model,
                                          perform_initialization=False)
        self.blocks = nn.ModuleList(
            [MPTBlock(config) for _ in range(config.n_layers)])
        self.norm_f = nn.LayerNorm(config.d_model)
        if config.no_bias:
            for module in self.modules():
                if hasattr(module, "bias"):
                    if isinstance(module.bias, nn.Parameter):
                        # Remove the bias term in Linear and LayerNorm.
                        module.register_parameter("bias", None)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.wte(input_ids)
        for i in range(len(self.blocks)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            block = self.blocks[i]
            hidden_states = block(
                position_ids,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.norm_f(hidden_states)
        return hidden_states


class MPTForCausalLM(nn.Module):

    def __init__(self, config: MPTConfig):
        super().__init__()
        self.config = config
        assert config.tie_word_embeddings

        self.transformer = MPTModel(config)
        # TODO(zhuohan): create a new weight after implementing pipeline
        # parallelism
        self.lm_head_weight = self.transformer.wte.weight
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = ["wte.weight", "up_proj.weight", "up_proj.bias"]
    _row_parallel_weights = ["out_proj.weight", "down_proj.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tp_world_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "Wqkv" in name:
                # NOTE(woosuk): MPT's fused QKV has the shape of
                # [3 * num_heads * head_size, hidden_size].
                # When tensor model parallelism is used, we need to shard
                # the weight along the hidden dimension.
                total_num_heads = self.config.num_attention_heads
                hidden_size = self.config.hidden_size
                head_size = hidden_size // total_num_heads
                num_heads = total_num_heads // tp_world_size
                head_start = tp_rank * num_heads
                head_end = (tp_rank + 1) * num_heads

                if name.endswith(".weight"):
                    loaded_weight = loaded_weight.view(3, total_num_heads,
                                                       head_size, hidden_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :, :]
                    loaded_weight = loaded_weight.reshape(-1, hidden_size)
                elif name.endswith(".bias"):
                    loaded_weight = loaded_weight.view(3, total_num_heads,
                                                       head_size)
                    loaded_weight = loaded_weight[:, head_start:head_end, :]
                    loaded_weight = loaded_weight.reshape(-1)
                else:
                    raise ValueError(f"Unexpected parameter name {name}")
                name = name.replace("Wqkv", "qkv_proj")
            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights, tp_rank)
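Note on the Wqkv handling above: the checkpoint stores the fused QKV weight as a [3 * num_heads * head_size, hidden_size] matrix, so each tensor-parallel rank has to slice out its own heads before copying the weight. The following standalone sketch (toy sizes; all names and dimensions here are illustrative, not part of the diff) mirrors the view/slice/reshape steps in load_weights:

import torch

# Toy dimensions, chosen only for illustration.
total_num_heads = 4
head_size = 8
hidden_size = total_num_heads * head_size
tp_world_size = 2
num_heads = total_num_heads // tp_world_size  # heads kept per rank

# Fused QKV weight as stored in the MPT checkpoint:
# [3 * num_heads * head_size, hidden_size].
wqkv = torch.randn(3 * total_num_heads * head_size, hidden_size)

for tp_rank in range(tp_world_size):
    head_start = tp_rank * num_heads
    head_end = (tp_rank + 1) * num_heads
    # Expose the (qkv, head, head_dim, hidden) structure, keep only this
    # rank's heads, then flatten back to the ColumnParallelLinear layout.
    shard = wqkv.view(3, total_num_heads, head_size, hidden_size)
    shard = shard[:, head_start:head_end, :, :].reshape(-1, hidden_size)
    assert shard.shape == (3 * num_heads * head_size, hidden_size)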
@@ -0,0 +1,15 @@
from transformers import AutoConfig, PretrainedConfig

from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import

_CONFIG_REGISTRY = {
    "mpt": MPTConfig,
}


def get_config(model: str) -> PretrainedConfig:
    config = AutoConfig.from_pretrained(model, trust_remote_code=True)
    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
        config = config_class.from_pretrained(model)
    return config
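A minimal usage sketch for the helper above, assuming it is importable as vllm.transformers_utils.config.get_config (the file path is not shown in this view) and using the mosaicml/mpt-7b checkpoint referenced in the model code:

from vllm.transformers_utils.config import get_config  # assumed module path

# AutoConfig (with trust_remote_code=True) reports model_type == "mpt",
# so the registry swaps in vLLM's own MPTConfig class instead of the
# remote-code config shipped with the checkpoint.
config = get_config("mosaicml/mpt-7b")
print(type(config).__name__)  # MPTConfig
print(config.d_model, config.n_heads, config.n_layers)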
@@ -0,0 +1,5 @@
from vllm.transformers_utils.configs.mpt import MPTConfig

__all__ = [
    "MPTConfig",
]
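For intuition on the _get_alibi_slopes helper in the model file above: the slopes form a geometric sequence, and when the head count is not a power of two the helper builds the sequence for the next power of two, then interleaves and truncates it. A standalone re-derivation (using alibi_bias_max=8 purely for illustration):

import math

import torch


def alibi_slopes(total_num_heads: int, alibi_bias_max: int) -> torch.Tensor:
    # Same logic as _get_alibi_slopes in the MPT model file above.
    next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
    m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
    m = m.mul(alibi_bias_max / next_power_of_2)
    slopes = 1.0 / torch.pow(2, m)
    if next_power_of_2 != total_num_heads:
        # Odd-indexed slopes first, then even-indexed ones, truncated to
        # the actual number of heads.
        slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
    return slopes


# 8 heads (a power of two): 1/2, 1/4, 1/8, ..., 1/256.
print(alibi_slopes(8, 8))
# 6 heads: built from the 8-head sequence, interleaved and truncated:
# 1/4, 1/16, 1/64, 1/256, 1/2, 1/8.
print(alibi_slopes(6, 8))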