Commit
Add support for phi-3-vision-128k-instruct (#36)
* add phi3_v 

* Update test_models.py

* rebase branch

* remove debug print

* add prompt format

* add condition to fix quantisation

* bump version

---------

Co-authored-by: Prince Canuma <prince.gdt@gmail.com>
JosefAlbers and Blaizzy authored Jun 24, 2024
1 parent 7798682 commit 8ca2a55
Showing 9 changed files with 821 additions and 6 deletions.
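
The new model can be exercised roughly as sketched below (a minimal sketch, not part of this diff: the `load`/`generate` helper names, the quantised repo id, and the image path are assumptions, while the `<|user|>` / `<|image_1|>` / `<|end|>` / `<|assistant|>` turns follow the Phi-3-vision prompt format referenced in the commit message).

# Minimal usage sketch -- assumed API surface and repo id, not taken from this diff.
from mlx_vlm import load, generate

model, processor = load("mlx-community/Phi-3-vision-128k-instruct-4bit")  # hypothetical repo id

# Phi-3-vision chat format: image placeholder, user turn, then the assistant turn to complete.
prompt = "<|user|>\n<|image_1|>\nWhat is shown in this image?<|end|>\n<|assistant|>\n"

output = generate(model, processor, "example.jpg", prompt)  # "example.jpg" is a placeholder image path
print(output)
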
8 changes: 8 additions & 0 deletions mlx_vlm/models/phi3_v/__init__.py
@@ -0,0 +1,8 @@
from .phi3_v import (
    LanguageModel,
    Model,
    ModelConfig,
    TextConfig,
    VisionConfig,
    VisionModel,
)
19 changes: 19 additions & 0 deletions mlx_vlm/models/phi3_v/language.py
@@ -0,0 +1,19 @@
import inspect
from dataclasses import dataclass


@dataclass
class TextConfig:
    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )


class LanguageModel:
    pass
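
`TextConfig.from_dict` (and `ModelConfig.from_dict` below) filters the incoming dictionary against the dataclass signature, so extra keys in an upstream `config.json` are silently dropped instead of raising `TypeError`. A standalone sketch of the same pattern, using a hypothetical `TinyConfig`:

import inspect
from dataclasses import dataclass


@dataclass
class TinyConfig:  # hypothetical config used only for illustration
    hidden_size: int = 3072

    @classmethod
    def from_dict(cls, params):
        # Keep only keys that match a declared dataclass field.
        return cls(
            **{k: v for k, v in params.items() if k in inspect.signature(cls).parameters}
        )


print(TinyConfig.from_dict({"hidden_size": 3072, "unknown_key": 1}))
# TinyConfig(hidden_size=3072) -- "unknown_key" is ignored rather than raising TypeError
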
234 changes: 234 additions & 0 deletions mlx_vlm/models/phi3_v/phi3_v.py
@@ -0,0 +1,234 @@
import inspect
import math
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from .language import LanguageModel, TextConfig
from .su_rope import Phi3SuScaledRotaryEmbedding
from .vision import VisionConfig, VisionModel


@dataclass
class ModelConfig:
    text_config: TextConfig
    vision_config: VisionConfig
    model_type: str
    vocab_size: int

    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float

    ignore_index: int = -100
    image_token_index: int = 257152
    hidden_size: int = 2048
    pad_token_id: int = 0

    num_key_value_heads: int = None
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    max_position_embeddings: int = 131072
    original_max_position_embeddings: int = 4096

    @classmethod
    def from_dict(cls, params):
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )

class Attention(nn.Module):
    def __init__(self, args: TextConfig):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        self.num_hidden_layers = args.num_hidden_layers

        self.head_dim = head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
        self.qkv_proj = nn.Linear(dim, op_size, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        rope_scale = 1.0
        if args.rope_scaling and args.rope_scaling["type"] == "su":
            self.rope = Phi3SuScaledRotaryEmbedding(
                head_dim,
                traditional=False,
                base=args.rope_theta,
                scale=rope_scale,
                max_position_embeddings=args.max_position_embeddings,
                original_max_position_embeddings=args.original_max_position_embeddings,
                short_factor=args.rope_scaling["short_factor"],
                long_factor=args.rope_scaling["long_factor"],
            )
        else:
            if args.rope_scaling and args.rope_scaling["type"] == "linear":
                rope_scale = 1 / args.rope_scaling["factor"]
            self.rope = nn.RoPE(
                head_dim,
                traditional=args.rope_traditional,
                base=args.rope_theta,
                scale=rope_scale,
            )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        qkv = self.qkv_proj(x)
        query_pos = self.n_heads * self.head_dim
        queries, keys, values = mx.split(
            qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1
        )

        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            offset = cache[0].shape[2]
            queries = self.rope(queries, offset=offset)
            keys = self.rope(keys, offset=offset)
            keys = mx.concatenate([cache[0], keys], axis=2)
            values = mx.concatenate([cache[1], values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), (keys, values)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_up_proj = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def __call__(self, x) -> mx.array:
        x = self.gate_up_proj(x)
        gate, x = mx.split(x, 2, axis=-1)
        return self.down_proj(nn.silu(gate) * x)


class TransformerBlock(nn.Module):
    def __init__(self, args: TextConfig):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out, cache


class Phi3V(nn.Module):
    def __init__(self, args: TextConfig):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.vision_embed_tokens = VisionModel(args)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        pixel_values=None,
        image_sizes=None,
        cache=None,
    ):
        h = self.embed_tokens(inputs)
        p = np.argwhere(inputs < 0).tolist()
        if pixel_values is not None:
            h = self.vision_embed_tokens(pixel_values, h, image_sizes, p)
        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)
        if cache is None:
            cache = [None] * len(self.layers)
        for i, layer in enumerate(self.layers):
            h, cache[i] = layer(h, mask, cache[i])
        return self.norm(h), cache


class Model(nn.Module):
    def __init__(self, args: TextConfig):
        super().__init__()
        self.model_type = args.model_type
        self.model = Phi3V(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
        self.config = args

    def __call__(
        self,
        inputs: mx.array,
        pixel_values=None,
        mask=None,
        cache=None,
    ):
        out, cache = self.model(inputs, pixel_values, mask, cache)
        return self.lm_head(out).astype(self.lm_head.weight.dtype), cache

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        return self.args.hidden_size // self.args.num_attention_heads

    @property
    def n_kv_heads(self):
        return self.args.num_key_value_heads

    @property
    def language_model(self):
        return self

    @property
    def vision_model(self):
        return self.model.vision_embed_tokens
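
`Phi3V.__call__` locates image positions by looking for negative token ids: the processor is expected to encode the `<|image_1|>` placeholder as a negative id, and `np.argwhere(inputs < 0)` yields the (batch, position) pairs where the vision embeddings are spliced into `h`. A small sketch of that lookup, with hypothetical token ids:

import numpy as np

# Hypothetical encoded prompt: -1 stands in for the <|image_1|> placeholder token.
input_ids = np.array([[1, 32010, -1, 1724, 338, 445, 29973, 32007]])

positions = np.argwhere(input_ids < 0).tolist()
print(positions)  # [[0, 2]] -> batch 0, sequence position 2 receives the image features
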
71 changes: 71 additions & 0 deletions mlx_vlm/models/phi3_v/su_rope.py
@@ -0,0 +1,71 @@
import math

import mlx.core as mx


class Phi3SuScaledRotaryEmbedding:
    def __init__(
        self,
        dims: int,
        traditional: bool = False,
        base: float = 10000.0,
        scale: float = 1.0,
        max_position_embeddings: int = 131072,
        original_max_position_embeddings: int = 4096,
        short_factor: list[float] | float = 1.0,
        long_factor: list[float] | float = 1.0,
    ):
"""
Phi3Su Scaled Rotary Embedding layer for Phi-3 models.
Args:
dims (int): The feature dimensions to be rotated.
traditional (bool, optional): Unused. Default: ``False``.
base (int, optional): Base for the exponential scaling.
scale (float, optional): The scale used to scale the positions. Default: 1.0.
max_position_embeddings (int, optional): The maximum sequence length that this model was trained with. This is used to determine the size of the original RoPE embeddings when using long scaling. Default: 131072.
original_max_position_embeddings (int, optional): The maximum sequence length that this model was trained with. This is used to determine the size of the original RoPE embeddings when using long scaling. Default: 4096.
short_factor (float or list of floats, optional): List of scaling factors for sequences of length lesser than original_max_position_embeddings. Default: 1.0.
long_factor (float or list of floats, optional): List of scaling factors for sequences of length greater than original_max_position_embeddings. Default: 1.0.
"""
        self.inv_freq_short = 1.0 / (
            mx.array(short_factor, dtype=mx.float32)
            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
        )
        self.inv_freq_long = 1.0 / (
            scale
            * mx.array(long_factor, dtype=mx.float32)
            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
        )
        self.original_max_position_embeddings = original_max_position_embeddings
        self.scaling_factor = math.sqrt(
            1
            + math.log(max_position_embeddings / original_max_position_embeddings)
            / math.log(original_max_position_embeddings)
        )

    def _get_cos_sin(self, offset, L):
        position_ids = mx.arange(offset, offset + L, dtype=mx.float32)[None]
        inv_freq = (
            self.inv_freq_long
            if position_ids.max() + 1 > self.original_max_position_embeddings
            else self.inv_freq_short
        )
        inv_freq_expanded = mx.repeat(
            inv_freq[None, :, None], position_ids.shape[0], axis=0
        )
        position_ids_expanded = position_ids[:, None, :]
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1)
        emb = mx.concatenate([freqs, freqs], axis=-1)
        cos = mx.cos(emb) * self.scaling_factor
        sin = mx.sin(emb) * self.scaling_factor
        return mx.expand_dims(cos, axis=1), mx.expand_dims(sin, axis=1)

    def __call__(self, x, offset: int = 0):
        def _rotate_half(_x):
            midpoint = _x.shape[-1] // 2
            x1, x2 = _x[..., :midpoint], _x[..., midpoint:]
            return mx.concatenate([-x2, x1], axis=-1)

        cos, sin = self._get_cos_sin(offset, x.shape[2])
        return (x * cos) + (_rotate_half(x) * sin)
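
The `scaling_factor` above compensates for the extended context: with the default `max_position_embeddings=131072` and `original_max_position_embeddings=4096` it slightly amplifies the cos/sin tables. A quick worked check with those default values:

import math

# Attention-scaling term from Phi3SuScaledRotaryEmbedding with the default lengths.
max_pos, orig_max_pos = 131072, 4096
scaling_factor = math.sqrt(1 + math.log(max_pos / orig_max_pos) / math.log(orig_max_pos))
print(round(scaling_factor, 4))  # ~1.1902
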