
Add support for MPT #334

Merged · 29 commits · Jul 3, 2023
Changes from all commits:

1 change: 1 addition & 0 deletions README.md
@@ -46,6 +46,7 @@ vLLM seamlessly supports many Huggingface models, including the following architectures:
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
- LLaMA (`lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)

Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
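
For context, a minimal offline-inference sketch using vLLM's Python API with one of the newly supported MPT checkpoints (the prompt and sampling settings below are illustrative, not part of the diff):

from vllm import LLM, SamplingParams

# Any MPT checkpoint from the list above should work here.
llm = LLM(model="mosaicml/mpt-7b")
sampling_params = SamplingParams(temperature=0.8, max_tokens=64)

outputs = llm.generate(["MosaicML's MPT models are"], sampling_params)
print(outputs[0].outputs[0].text)
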
3 changes: 3 additions & 0 deletions csrc/attention/attention_kernels.cu
@@ -395,6 +395,9 @@ void single_query_cached_kv_attention_launcher(
case 96:
LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS);
break;
case 112:
LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS);
break;
case 128:
LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS);
break;
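
For context, the new 112 case matches MPT-30B's head dimension; a quick sanity check, assuming the published mosaicml/mpt-30b dimensions (these numbers are an assumption, not part of the diff):

# Assumed mosaicml/mpt-30b dimensions; illustrative only.
d_model = 7168
n_heads = 64
head_size = d_model // n_heads
print(head_size)  # 112 -> now handled by the CUDA switch above and _SUPPORTED_HEAD_SIZES below
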
3 changes: 3 additions & 0 deletions docs/source/models/supported_models.rst
@@ -29,6 +29,9 @@ Alongside each architecture, we include some popular models that use it.
* - :code:`LlamaForCausalLM`
- LLaMA, Vicuna, Alpaca, Koala, Guanaco
- :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, :code:`JosephusCheung/Guanaco`, etc.
* - :code:`MPTForCausalLM`
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
* - :code:`OPTForCausalLM`
- OPT, OPT-IML
- :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
5 changes: 3 additions & 2 deletions vllm/config.py
@@ -1,9 +1,10 @@
from typing import Optional

import torch
from transformers import AutoConfig, PretrainedConfig
from transformers import PretrainedConfig

from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
from vllm.utils import get_cpu_memory

logger = init_logger(__name__)
@@ -49,7 +50,7 @@ def __init__(
self.use_dummy_weights = use_dummy_weights
self.seed = seed

self.hf_config: PretrainedConfig = AutoConfig.from_pretrained(model)
self.hf_config = get_config(model)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self._verify_tokenizer_mode()

2 changes: 1 addition & 1 deletion vllm/model_executor/layers/attention.py
@@ -12,7 +12,7 @@
from vllm import pos_encoding_ops
from vllm.model_executor.input_metadata import InputMetadata

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128]


class PagedAttention(nn.Module):
3 changes: 2 additions & 1 deletion vllm/model_executor/model_loader.py
@@ -16,7 +16,8 @@
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"LLaMAForCausalLM": LlamaForCausalLM,
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
"MPTForCausalLM": MPTForCausalLM,
"OPTForCausalLM": OPTForCausalLM,
}

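As a rough sketch of how the registry above is consulted (illustrative only, not the PR's code; the helper name and error message are hypothetical):

from typing import Type
import torch.nn as nn

def resolve_model_class(hf_config) -> Type[nn.Module]:
    # Pick the first architecture listed in the HF config that has an
    # entry in _MODEL_REGISTRY (e.g. "MPTForCausalLM" -> MPTForCausalLM).
    for arch in getattr(hf_config, "architectures", []) or []:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(f"Unsupported model architectures: {hf_config.architectures}")
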
2 changes: 2 additions & 0 deletions vllm/model_executor/models/__init__.py
@@ -3,6 +3,7 @@
from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
from vllm.model_executor.models.llama import LlamaForCausalLM
from vllm.model_executor.models.mpt import MPTForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM

__all__ = [
@@ -11,5 +12,6 @@
"GPTBigCodeForCausalLM",
"GPTNeoXForCausalLM",
"LlamaForCausalLM",
"MPTForCausalLM",
"OPTForCausalLM",
]
279 changes: 279 additions & 0 deletions vllm/model_executor/models/mpt.py
@@ -0,0 +1,279 @@
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithALiBi
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.sequence import SequenceOutputs
from vllm.transformers_utils.configs.mpt import MPTConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]


def _get_alibi_slopes(
total_num_heads: int,
alibi_bias_max: int,
) -> torch.Tensor:
next_power_of_2 = 2**math.ceil(math.log2(total_num_heads))
m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32)
m = m.mul(alibi_bias_max / next_power_of_2)
slopes = 1.0 / torch.pow(2, m)
if next_power_of_2 != total_num_heads:
slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads]
return slopes
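# Worked example (illustrative): with total_num_heads=8 and alibi_bias_max=8,
# the slopes above come out as 2^-1, 2^-2, ..., 2^-8, one per head. When the
# head count is not a power of two, slopes are computed for the next power of
# two and then reordered (every other slope starting at index 1, followed by
# every other slope starting at index 0) before truncating to total_num_heads.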


class MPTAttention(nn.Module):

def __init__(self, config: MPTConfig):
super().__init__()
self.d_model = config.d_model
self.total_num_heads = config.n_heads
self.clip_qkv = config.attn_config["clip_qkv"]
self.qk_ln = config.attn_config["qk_ln"]
self.alibi_bias_max = config.attn_config["alibi_bias_max"]
assert not config.attn_config["prefix_lm"]
assert config.attn_config["alibi"]

self.qkv_proj = ColumnParallelLinear(
self.d_model,
3 * self.d_model,
bias=not config.no_bias,
gather_output=False,
perform_initialization=False,
)
if self.qk_ln:
self.q_ln = nn.LayerNorm(self.d_model)
self.k_ln = nn.LayerNorm(self.d_model)
self.out_proj = RowParallelLinear(
self.d_model,
self.d_model,
bias=not config.no_bias,
input_is_parallel=True,
perform_initialization=False,
)

tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
self.num_heads = self.total_num_heads // tp_world_size

# Create the alibi slopes and slice them.
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(self.total_num_heads,
self.alibi_bias_max)
alibi_slopes = alibi_slopes[head_start:head_end].tolist()

self.head_dim = self.d_model // self.total_num_heads
scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithALiBi(self.num_heads, self.head_dim,
scaling, alibi_slopes)

def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
del position_ids # unused.
qkv, _ = self.qkv_proj(hidden_states)
if self.clip_qkv is not None:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
q, k, v = qkv.chunk(chunks=3, dim=-1)
if self.qk_ln:
q = self.q_ln(q)
k = self.k_ln(k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
cache_event)
output, _ = self.out_proj(attn_output)
return output


class MPTMLP(nn.Module):

def __init__(self, config: MPTConfig):
super().__init__()
hidden_size = config.d_model
expansion_ratio = config.expansion_ratio
intermediate_size = expansion_ratio * hidden_size
self.up_proj = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=not config.no_bias,
gather_output=False,
perform_initialization=False)
self.act = get_act_fn("gelu")
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=not config.no_bias,
input_is_parallel=True,
perform_initialization=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.up_proj(x)
x = self.act(x)
x, _ = self.down_proj(x)
return x


class MPTBlock(nn.Module):

def __init__(self, config: MPTConfig):
super().__init__()
hidden_size = config.d_model
self.norm_1 = nn.LayerNorm(hidden_size)
self.attn = MPTAttention(config)
self.norm_2 = nn.LayerNorm(hidden_size)
self.ffn = MPTMLP(config)

def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
x = self.norm_1(hidden_states)
x = self.attn(
position_ids=position_ids,
hidden_states=x,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = hidden_states + x
x = self.norm_2(hidden_states)
x = self.ffn(x)
hidden_states = hidden_states + x
return hidden_states


class MPTModel(nn.Module):

def __init__(self, config: MPTConfig):
super().__init__()
assert config.embedding_fraction == 1.0
assert config.norm_type == "low_precision_layernorm"

self.wte = VocabParallelEmbedding(config.vocab_size,
config.d_model,
perform_initialization=False)
self.blocks = nn.ModuleList(
[MPTBlock(config) for _ in range(config.n_layers)])
self.norm_f = nn.LayerNorm(config.d_model)
if config.no_bias:
for module in self.modules():
if hasattr(module, "bias"):
if isinstance(module.bias, nn.Parameter):
# Remove the bias term in Linear and LayerNorm.
module.register_parameter("bias", None)

def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.blocks)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
block = self.blocks[i]
hidden_states = block(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm_f(hidden_states)
return hidden_states


class MPTForCausalLM(nn.Module):

def __init__(self, config: MPTConfig):
super().__init__()
self.config = config
assert config.tie_word_embeddings

self.transformer = MPTModel(config)
# TODO(zhuohan): create a new weight after implementing pipeline
# parallelism
self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> Dict[int, SequenceOutputs]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
return next_tokens

_column_parallel_weights = ["wte.weight", "up_proj.weight", "up_proj.bias"]
_row_parallel_weights = ["out_proj.weight", "down_proj.weight"]

def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_np_cache: bool = False):
tp_world_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, use_np_cache):
if "Wqkv" in name:
# NOTE(woosuk): MPT's fused QKV has the shape of
# [3 * num_heads * head_size, hidden_size].
# When tensor model parallelism is used, we need to shard
# the weight along the hidden dimension.
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tp_world_size
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads

if name.endswith(".weight"):
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size, hidden_size)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
elif name.endswith(".bias"):
loaded_weight = loaded_weight.view(3, total_num_heads,
head_size)
loaded_weight = loaded_weight[:, head_start:head_end, :]
loaded_weight = loaded_weight.reshape(-1)
else:
raise ValueError(f"Unexpected parameter name {name}")
name = name.replace("Wqkv", "qkv_proj")
param = state_dict[name]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights, tp_rank)
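
To make the fused-QKV sharding in load_weights above concrete, here is a shape-only sketch with hypothetical small dimensions (the sizes are illustrative, not MPT's real ones):

import torch

hidden_size, total_num_heads, tp_world_size, tp_rank = 64, 8, 2, 0
head_size = hidden_size // total_num_heads            # 8
num_heads = total_num_heads // tp_world_size          # 4 heads per TP rank
head_start, head_end = tp_rank * num_heads, (tp_rank + 1) * num_heads

wqkv = torch.randn(3 * hidden_size, hidden_size)      # fused [3 * n_heads * head_size, hidden]
shard = (wqkv.view(3, total_num_heads, head_size, hidden_size)
             [:, head_start:head_end, :, :]
             .reshape(-1, hidden_size))
print(shard.shape)  # torch.Size([96, 64]) == [3 * num_heads * head_size, hidden_size]
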
15 changes: 15 additions & 0 deletions vllm/transformers_utils/config.py
@@ -0,0 +1,15 @@
from transformers import AutoConfig, PretrainedConfig

from vllm.transformers_utils.configs import * # pylint: disable=wildcard-import

_CONFIG_REGISTRY = {
"mpt": MPTConfig,
}


def get_config(model: str) -> PretrainedConfig:
config = AutoConfig.from_pretrained(model, trust_remote_code=True)
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model)
return config
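
A quick usage sketch for the helper above (illustrative; it needs access to the Hugging Face Hub to resolve the config):

from vllm.transformers_utils.config import get_config

config = get_config("mosaicml/mpt-7b")
print(type(config).__name__)                 # MPTConfig, since model_type == "mpt"
print(config.d_model, config.n_heads, config.n_layers)
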
5 changes: 5 additions & 0 deletions vllm/transformers_utils/configs/__init__.py
@@ -0,0 +1,5 @@
from vllm.transformers_utils.configs.mpt import MPTConfig

__all__ = [
"MPTConfig",
]