[Model] Add moondream vision language model
vikhyat committed Apr 20, 2024
1 parent cc74b2b commit 355a6ea
Showing 6 changed files with 331 additions and 10 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -73,6 +73,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.)
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.)
- Moondream (`vikhyatk/moondream2`, `vikhyatk/moondream1`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
49 changes: 49 additions & 0 deletions examples/moondream_example.py
@@ -0,0 +1,49 @@
import torch
from torchvision.transforms.v2 import (
Compose,
Resize,
InterpolationMode,
ToImage,
ToDtype,
Normalize,
)
from PIL import Image

from vllm import LLM, SamplingParams
from vllm.sequence import MultiModalData

if __name__ == "__main__":

sampling_params = SamplingParams(temperature=0, max_tokens=256)
llm = LLM(
model="vikhyatk/moondream2",
trust_remote_code=True,
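        # Vision config: 378x378 pixel input produces 729 image feature tokens
        # (a 27x27 grid of 14x14 patches); <|endoftext|> (id 50256) serves as
        # the image placeholder token.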
image_input_type="pixel_values",
image_token_id=50256,
image_input_shape="1,3,378,378",
image_feature_size=729,
)

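    # Resize to the model's 378x378 input resolution and normalize pixel
    # values to [-1, 1].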
preprocess = Compose(
[
Resize(size=(378, 378), interpolation=InterpolationMode.BICUBIC),
ToImage(),
ToDtype(torch.float32, scale=True),
Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
)

image = Image.open("docs/source/assets/kernel/value.png").convert("RGB")
image_pixels = preprocess(image).unsqueeze(0)

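    # The prompt reserves one <|endoftext|> placeholder per image feature
    # (729 in total); the model replaces these embeddings with the projected
    # image features.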
outputs = llm.generate(
[("<|endoftext|>" * 729) + "\n\nQuestion: Describe this image.\n\nAnswer:"],
multi_modal_data=MultiModalData(
type=MultiModalData.Type.IMAGE, data=image_pixels
),
sampling_params=sampling_params,
)

for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
5 changes: 2 additions & 3 deletions vllm/model_executor/model_loader/loader.py
@@ -23,13 +23,12 @@
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
+from vllm.model_executor.models.moondream import Moondream

if TYPE_CHECKING:
from vllm.model_executor.layers.linear import LinearMethodBase

-_VISION_MODEL_CLASSES = [
-    LlavaForConditionalGeneration,
-]
+_VISION_MODEL_CLASSES = [LlavaForConditionalGeneration, Moondream]

logger = init_logger(__name__)

4 changes: 2 additions & 2 deletions vllm/model_executor/models/__init__.py
@@ -31,8 +31,8 @@
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"LlavaForConditionalGeneration":
("llava", "LlavaForConditionalGeneration"),
"LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
"Moondream": ("moondream", "Moondream"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
266 changes: 266 additions & 0 deletions vllm/model_executor/models/moondream.py
@@ -0,0 +1,266 @@
from typing import Optional, Iterable, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from vllm.attention import AttentionMetadata
from vllm.config import VisionLanguageConfig
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.models.phi import PhiForCausalLM
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput


class Attention(nn.Module):
def __init__(self, dim, num_heads=16):
super().__init__()
assert dim % num_heads == 0, "dim should be divisible by num_heads"

self.num_heads = num_heads
self.head_dim = dim // num_heads

self.qkv = nn.Linear(dim, dim * 3)
self.proj = nn.Linear(dim, dim)

def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = (
self.qkv(x)
.reshape(B, N, 3, self.num_heads, self.head_dim)
.permute(2, 0, 3, 1, 4)
)
q, k, v = qkv.unbind(0)

        # TODO: Replace with vLLM attention implementation after adding support
        # for non-causal attention.
x = F.scaled_dot_product_attention(q, k, v)

x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x


class VitBlock(nn.Module):
def __init__(self, embed_dim):
super().__init__()
self.attn = Attention(embed_dim)
self.mlp = MLP(embed_dim, 4304)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)

def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x


class VisionTransformer(nn.Module):
def __init__(self):
super().__init__()

embed_len = 729
embed_dim = 1152
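        # 729 patch positions (27x27 grid from a 378x378 image with 14x14
        # patches), each embedded in 1152 dimensions.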

self.patch_embed = LinearPatchEmbedding()
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02)
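        # 27 pre-norm transformer blocks followed by a final LayerNorm.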
self.blocks = nn.Sequential(*[VitBlock(embed_dim) for _ in range(27)])
self.norm = nn.LayerNorm(embed_dim)

def forward(self, x):
x = self.patch_embed(x)
x = x + self.pos_embed
for block in self.blocks:
x = block(x)
return self.norm(x)


class EncoderWrapper(nn.Module):
def __init__(self):
super().__init__()
self.model = nn.ModuleDict({"visual": VisionTransformer()})

def forward(self, x):
return self.model["visual"](x)


class LinearPatchEmbedding(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(588, 1152)

def forward(self, x):
b, c, hp1, wp2 = x.shape
p1, p2 = 14, 14
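        # Flatten each non-overlapping 14x14 patch into a 588-dim vector
        # (3 channels * 14 * 14) before projecting to the ViT embedding size.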
h, w = hp1 // p1, wp2 // p2
x = x.reshape(b, c, h, p1, w, p2)
x = x.permute(0, 2, 4, 1, 3, 5)
x = x.reshape(b, h * w, c * p1 * p2)

return self.linear(x)


class MLP(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: int = None,
out_features: int = None,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = nn.GELU(approximate="tanh")
self.fc2 = nn.Linear(hidden_features, out_features)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x


class VisionProjection(nn.Module):
def __init__(self):
super().__init__()

image_embedding_dim = 1152
model_dim = 2048
hidden_dim = model_dim * 4
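        # Project 1152-dim vision features to the 2048-dim text model hidden
        # size through a 4x-wide MLP.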

self.mlp = MLP(image_embedding_dim, hidden_dim, model_dim)

def forward(self, x):
return self.mlp(x)


class VisionEncoder(nn.Module):
def __init__(self):
super().__init__()
self.encoder = EncoderWrapper()
self.projection = VisionProjection()

def forward(self, x) -> torch.Tensor:
x = self.encoder(x)
x = self.projection(x)
return x


class Moondream(nn.Module):
def __init__(
self,
config: PretrainedConfig,
vision_language_config: VisionLanguageConfig,
linear_method: Optional["LinearMethodBase"] = None,
) -> None:
super().__init__()
self.config = config

self.vision_language_config = vision_language_config

assert self.vision_language_config, (
"Provide `image_input_type` and other vision "
"related configurations through LLM entrypoint "
"or engine arguments."
)

if self.vision_language_config.image_input_type == (
VisionLanguageConfig.ImageInputType.PIXEL_VALUES
):
self.vision_encoder = VisionEncoder()
else:
self.vision_encoder = None

self.linear_method = linear_method

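        # The language backbone reuses vLLM's Phi implementation, built from
        # the text sub-config of the Hugging Face config.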
self.text_model = PhiForCausalLM(config.text_config)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
image_input: Optional[torch.Tensor] = None,
) -> SamplerOutput:
if image_input is not None:
if list(image_input.shape[1:]) != list(
self.vision_language_config.image_input_shape[1:]
):
raise ValueError(
f"The expected image tensor shape is batch dimension "
f"plus "
f"{self.vision_language_config.image_input_shape[1:]}."
f" You supplied {image_input.shape}. "
f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with "
f"image_input_shape in engine args."
)

if self.vision_encoder is not None:
image_features = self.vision_encoder(image_input)
else:
image_features = image_input

inputs_embeds = self.text_model.model.embed_tokens(input_ids)
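            # Replace the embeddings of the image placeholder tokens with the
            # projected vision features, one feature vector per placeholder.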
mask = input_ids == self.vision_language_config.image_token_id
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
else:
inputs_embeds = None

hidden_states = self.text_model(
input_ids, positions, kv_caches, attn_metadata, inputs_embeds=inputs_embeds
)
return hidden_states

def compute_logits(
self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata
) -> torch.Tensor:
return self.text_model.compute_logits(hidden_states, sampling_metadata)

def sample(
self, logits: torch.Tensor, sampling_metadata: SamplingMetadata
) -> Optional[SamplerOutput]:
return self.text_model.sample(logits, sampling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())

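        # Checkpoint weights whose names map directly onto differently named
        # vLLM parameters.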
params_map = {
"text_model.transformer.embd.wte.weight": "text_model.model.embed_tokens.weight",
"text_model.lm_head.linear.weight": "text_model.lm_head.weight",
"text_model.lm_head.linear.bias": "text_model.lm_head.bias",
"text_model.lm_head.ln.weight": "text_model.model.final_layernorm.weight",
"text_model.lm_head.ln.bias": "text_model.model.final_layernorm.bias",
}

for name, loaded_weight in weights:
param = None

if name in params_map:
param = params_dict[params_map[name]]
elif name in params_dict:
param = params_dict[name]
elif name.startswith("text_model."):
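                # Remaining decoder weights: rewrite the checkpoint's layer
                # and attention module names to vLLM's Phi module names.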
replacements = {
"text_model.transformer.h": "text_model.model.layers",
"ln": "input_layernorm",
"mixer.Wqkv": "self_attn.qkv_proj",
"mixer.out_proj": "self_attn.dense",
}

mp = name
for k, v in replacements.items():
if k in mp:
mp = mp.replace(k, v)
if mp in params_dict:
param = params_dict[mp]

if param is None:
raise ValueError(f"Unmapped weight: {name}")
else:
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
16 changes: 11 additions & 5 deletions vllm/model_executor/models/phi.py
@@ -205,8 +205,13 @@ def forward(
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
+        inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
-        hidden_states = self.embed_tokens(input_ids)
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+        else:
+            hidden_states = self.embed_tokens(input_ids)

for i in range(self.config.num_hidden_layers):
layer = self.layers[i]
hidden_states = layer(
@@ -244,9 +249,11 @@ def forward(
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
+        inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata)
+        hidden_states = self.model(
+            input_ids, positions, kv_caches, attn_metadata, inputs_embeds
+        )

return hidden_states

@@ -295,6 +302,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# pylint: disable=E1136

param = params_dict[name]
-            weight_loader = getattr(param, "weight_loader",
-                                    default_weight_loader)
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
