Commit: Add support for Llava-Next (v1.6) (#43)

* add llava-next
* add tests
Showing 6 changed files with 718 additions and 1 deletion.
@@ -0,0 +1,8 @@
from .llava_next import (
    LanguageModel,
    Model,
    ModelConfig,
    TextConfig,
    VisionConfig,
    VisionModel,
)
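
The new package `__init__.py` above simply re-exports the classes defined in `llava_next.py` (shown next), presumably so shared loading code can construct them by name. As a rough usage sketch, and not part of this commit: the import path below is abbreviated for illustration and the config dict is made up rather than taken from a real checkpoint. `TextConfig.from_dict`, defined in the next file, filters out keys that are not dataclass fields:

# Hypothetical flat import path, for illustration only; in the repo these names
# live under the new llava_next package added by this commit.
from llava_next import TextConfig

# Keys that are not TextConfig fields (e.g. "torch_dtype") are dropped by from_dict.
text_config = TextConfig.from_dict(
    {"model_type": "llama", "hidden_size": 4096, "torch_dtype": "float16"}
)
print(text_config.num_hidden_layers)  # -> 32, the dataclass default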
@@ -0,0 +1,215 @@
import inspect
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn


@dataclass
class TextConfig:
    model_type: str
    hidden_size: int = 4096
    num_hidden_layers: int = 32
    intermediate_size: int = 14336
    num_attention_heads: int = 32
    rms_norm_eps: float = 1e-05
    vocab_size: int = 32064
    num_key_value_heads: int = 8
    rope_theta: float = 1000000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None

    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )

    def __post_init__(self):
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        if self.rope_scaling:
            required_keys = {"factor", "type"}
            if not all(key in self.rope_scaling for key in required_keys):
                raise ValueError(f"rope_scaling must contain keys {required_keys}")

            if self.rope_scaling["type"] != "linear":
                raise ValueError("rope_scaling 'type' currently only supports 'linear'")


class Attention(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()

        dim = config.hidden_size
        self.n_heads = n_heads = config.num_attention_heads
        self.n_kv_heads = n_kv_heads = config.num_key_value_heads

        self.repeats = n_heads // n_kv_heads

        head_dim = config.hidden_size // n_heads
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        rope_scale = (
            1 / config.rope_scaling["factor"]
            if config.rope_scaling is not None
            and config.rope_scaling["type"] == "linear"
            else 1
        )
        self.rope = nn.RoPE(
            head_dim,
            traditional=config.rope_traditional,
            base=config.rope_theta,
            scale=rope_scale,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), (keys, values)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.self_attn = Attention(config)
        self.mlp = MLP(config.hidden_size, config.intermediate_size)
        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.config = config

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out, cache


class Llama(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.num_hidden_layers = config.num_hidden_layers
        assert self.vocab_size > 0
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = [
            TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        inputs_embeds=None,
    ):
        # for passing merged input embeddings
        if inputs_embeds is None:
            h = self.embed_tokens(inputs)
        else:
            h = inputs_embeds

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for e, layer in enumerate(self.layers):
            h, cache[e] = layer(h, mask, cache[e])

        return self.norm(h), cache


class LanguageModel(nn.Module):
    def __init__(self, config: TextConfig):
        super().__init__()
        self.model_type = config.model_type
        if self.model_type not in ["mistral", "llama"]:
            raise ValueError(
                f"Model type {self.model_type} not supported. Currently only 'llama' and 'mistral' are supported"
            )
        self.model = Llama(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
        inputs_embeds=None,
        mask: Optional[mx.array] = None,
    ):
        out, cache = self.model(inputs, cache, inputs_embeds)
        return self.lm_head(out), cache

    @staticmethod
    def sanitize(weights):
        # Remove unused precomputed rotary freqs
        return {
            k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
        }

    @property
    def layers(self):
        return self.model.layers
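
Taken together, `LanguageModel` returns logits plus an updated per-layer `(keys, values)` cache, so generation can prefill on the prompt and then decode one token at a time with `inputs_embeds` left as `None`. Below is a minimal sketch of that loop; the toy config values and the greedy sampling are illustrative assumptions, not part of this commit, and the full Llava-Next pipeline in this PR presumably goes through the `Model` class from the other changed files (not shown here), which supplies merged text/image embeddings via `inputs_embeds`.

import mlx.core as mx

# Toy config purely for illustration; real Llava-Next checkpoints use the defaults above.
config = TextConfig(
    model_type="llama",
    hidden_size=512,
    num_hidden_layers=2,
    intermediate_size=1024,
    num_attention_heads=8,
    num_key_value_heads=4,
    vocab_size=1000,
)
lm = LanguageModel(config)

prompt = mx.array([[1, 2, 3, 4]])   # (batch, seq_len) token ids
logits, cache = lm(prompt)          # prefill: cache holds one (keys, values) pair per layer

token = mx.argmax(logits[:, -1, :], axis=-1).reshape(1, 1)
for _ in range(8):
    logits, cache = lm(token, cache=cache)  # decode one token against the growing cache
    token = mx.argmax(logits[:, -1, :], axis=-1).reshape(1, 1)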