Add support for Llava-Next (v1.6) #43

Merged 2 commits on Jun 22, 2024
Changes from 1 commit
add llava-next
Blaizzy committed Jun 22, 2024
commit 69344e79c273de4532ecaab9409294897670fd90
8 changes: 8 additions & 0 deletions mlx_vlm/models/llava_next/__init__.py
@@ -0,0 +1,8 @@
from .llava_next import (
    LanguageModel,
    Model,
    ModelConfig,
    TextConfig,
    VisionConfig,
    VisionModel,
)
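
For context (not part of the diff): these re-exports let downstream code import the Llava-Next classes from the package path rather than the inner submodule. A minimal sketch, assuming the repository is installed as the mlx_vlm package:

# Hypothetical downstream usage: import from the package, not from .llava_next.llava_next.
from mlx_vlm.models.llava_next import Model, ModelConfig, TextConfig, VisionConfig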
215 changes: 215 additions & 0 deletions mlx_vlm/models/llava_next/language.py
@@ -0,0 +1,215 @@
import inspect
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn


@dataclass
class TextConfig:
    model_type: str
    hidden_size: int = 4096
    num_hidden_layers: int = 32
    intermediate_size: int = 14336
    num_attention_heads: int = 32
    rms_norm_eps: float = 1e-05
    vocab_size: int = 32064
    num_key_value_heads: int = 8
    rope_theta: float = 1000000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None

    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        if self.rope_scaling:
            required_keys = {"factor", "type"}
            if not all(key in self.rope_scaling for key in required_keys):
                raise ValueError(f"rope_scaling must contain keys {required_keys}")

            if self.rope_scaling["type"] != "linear":
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")


class Attention(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()

        dim = config.hidden_size
        self.n_heads = n_heads = config.num_attention_heads
        self.n_kv_heads = n_kv_heads = config.num_key_value_heads

        self.repeats = n_heads // n_kv_heads

        head_dim = config.hidden_size // n_heads
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        rope_scale = (
            1 / config.rope_scaling["factor"]
            if config.rope_scaling is not None
            and config.rope_scaling["type"] == "linear"
            else 1
        )
        self.rope = nn.RoPE(
            head_dim,
            traditional=config.rope_traditional,
            base=config.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), (keys, values)
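
A stand-alone sketch of the KV-cache contract above (illustrative, not from the PR): the block returns the concatenated keys/values so the caller can feed them back on the next step, and the RoPE offset is taken from the cached sequence length. The tiny config values and the direct import of the language submodule are assumptions for the example.

import mlx.core as mx
from mlx_vlm.models.llava_next.language import Attention, TextConfig

cfg = TextConfig(
    model_type="llama", hidden_size=64, num_attention_heads=8, num_key_value_heads=4
)
attn = Attention(cfg)

x = mx.random.normal((1, 5, 64))     # (batch, seq_len, hidden_size)
out, (keys, values) = attn(x)        # prefill: no cache yet
print(out.shape, keys.shape)         # (1, 5, 64) (1, 4, 5, 8)

step = mx.random.normal((1, 1, 64))  # one new token
out, (keys, values) = attn(step, cache=(keys, values))
print(keys.shape)                    # (1, 4, 6, 8): cache grew by one position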


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.self_attn = Attention(config)
        self.mlp = MLP(config.hidden_size, config.intermediate_size)
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.config = config

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out, cache


class Llama(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [
            TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        inputs_embeds=None,
    ):
        # for passing merged input embeddings
        if inputs_embeds is None:
            h = self.embed_tokens(inputs)
        else:
            h = inputs_embeds

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for e, layer in enumerate(self.layers):
            h, cache[e] = layer(h, mask, cache[e])

        return self.norm(h), cache
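
A minimal prefill-then-decode sketch against this backbone (illustrative; the small config values are made up, and the language submodule import is an assumption). It shows the per-layer cache list being created on the first call and reused afterwards, and that the causal mask is only built when more than one position is processed.

import mlx.core as mx
from mlx_vlm.models.llava_next.language import Llama, TextConfig

cfg = TextConfig(
    model_type="llama",
    hidden_size=64,
    num_hidden_layers=2,
    intermediate_size=128,
    num_attention_heads=8,
    num_key_value_heads=4,
    vocab_size=100,
)
model = Llama(cfg)

prompt = mx.array([[1, 2, 3, 4]])               # (batch, seq_len) token ids
h, cache = model(prompt)                        # prefill: builds one (keys, values) per layer
h, cache = model(mx.array([[5]]), cache=cache)  # single-token step reuses the cache
print(h.shape)                                  # (1, 1, 64)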


class LanguageModel(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()
        self.model_type = config.model_type
        if self.model_type not in ["mistral", "llama"]:
            raise ValueError(
                f"Model type {self.model_type} not supported. Currently only 'llama' and 'mistral' are supported"
            )
        self.model = Llama(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        inputs_embeds=None,
        mask: Optional[mx.array] = None,
    ):
        out, cache = self.model(inputs, cache, inputs_embeds)
        return self.lm_head(out), cache

    @staticmethod
    def sanitize(weights):
        # Remove unused precomputed rotary freqs
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        return self.model.layers
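
A small sketch of what sanitize does when loading upstream checkpoints (the weight keys and placeholder values below are hypothetical): precomputed rotary-frequency buffers are dropped and everything else passes through unchanged.

from mlx_vlm.models.llava_next.language import LanguageModel

weights = {
    "model.layers.0.self_attn.q_proj.weight": None,
    "model.layers.0.self_attn.rotary_emb.inv_freq": None,
}
clean = LanguageModel.sanitize(weights)
print(sorted(clean))  # only the q_proj entry remains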