-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add support for phi-3-vision-128k-instruct (#36)
* add phi3_v * Update test_models.py * rebase branch * remove debug print * add prompt format * add condition to fix quantisation * bump version --------- Co-authored-by: Prince Canuma <prince.gdt@gmail.com>
- Loading branch information
1 parent
7798682
commit 8ca2a55
Showing
9 changed files
with
821 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from .phi3_v import ( | ||
LanguageModel, | ||
Model, | ||
ModelConfig, | ||
TextConfig, | ||
VisionConfig, | ||
VisionModel, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import inspect | ||
from dataclasses import dataclass | ||
|
||
|
||
@dataclass
class TextConfig:
    """Text-side configuration container.

    The class declares no fields, so its constructor takes no arguments and
    ``from_dict`` ends up discarding every key it is handed.
    """

    @classmethod
    def from_dict(cls, params):
        """Build an instance from ``params``, dropping unrecognised keys.

        Args:
            params: Mapping of configuration names to values.

        Returns:
            A new instance built only from the entries whose key matches a
            parameter name of the class constructor.
        """
        accepted = inspect.signature(cls).parameters
        kwargs = {name: value for name, value in params.items() if name in accepted}
        return cls(**kwargs)
|
||
|
||
class LanguageModel:
    """Empty placeholder class; it defines no attributes or methods."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,234 @@ | ||
import inspect | ||
import math | ||
from dataclasses import dataclass | ||
from types import SimpleNamespace | ||
from typing import Dict, Optional, Tuple, Union | ||
|
||
import mlx.core as mx | ||
import mlx.nn as nn | ||
import numpy as np | ||
|
||
from .language import LanguageModel, TextConfig | ||
from .su_rope import Phi3SuScaledRotaryEmbedding | ||
from .vision import VisionConfig, VisionModel | ||
|
||
|
||
@dataclass
class ModelConfig:
    """Configuration values consumed by the Phi-3-Vision model classes.

    Fields without a default must be supplied by the caller; ``from_dict``
    is the usual entry point and silently ignores keys that are not fields.
    """

    text_config: TextConfig
    vision_config: VisionConfig
    model_type: str
    vocab_size: int

    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float

    ignore_index: int = -100
    image_token_index: int = 257152
    hidden_size: int = 2048
    pad_token_id: int = 0

    # ``None`` means "not specified"; ``__post_init__`` resolves it.
    num_key_value_heads: Optional[int] = None
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    max_position_embeddings: int = 131072
    original_max_position_embeddings: int = 4096

    def __post_init__(self):
        """Resolve defaults that depend on other fields."""
        # Attention multiplies this value by the head dimension, so leaving
        # it as None would raise a TypeError while the layers are built.
        # Fall back to one key/value head per attention head.
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

    @classmethod
    def from_dict(cls, params):
        """Build a config from ``params``, dropping unrecognised keys.

        Args:
            params: Mapping of configuration names to values.

        Returns:
            A new ``ModelConfig`` built from the entries whose key matches a
            constructor parameter.

        Raises:
            TypeError: If a field without a default is missing from
                ``params``.
        """
        accepted = inspect.signature(cls).parameters
        return cls(**{k: v for k, v in params.items() if k in accepted})
|
||
|
||
class Attention(nn.Module):
    """Self-attention layer with a fused QKV projection and rotary embeddings."""

    def __init__(self, args: TextConfig):
        """Create the projections and pick the rotary embedding variant.

        ``args.rope_scaling`` selects the embedding: type ``"su"`` uses
        ``Phi3SuScaledRotaryEmbedding``; anything else uses ``nn.RoPE``, with
        the position scale set to ``1 / factor`` for type ``"linear"``.
        """
        super().__init__()

        hidden = args.hidden_size
        self.n_heads = args.num_attention_heads
        self.n_kv_heads = args.num_key_value_heads
        self.num_hidden_layers = args.num_hidden_layers

        self.head_dim = hidden // self.n_heads
        self.scale = self.head_dim**-0.5

        # One matrix produces queries, keys and values side by side.
        q_width = self.n_heads * self.head_dim
        kv_width = self.n_kv_heads * self.head_dim
        self.qkv_proj = nn.Linear(hidden, q_width + 2 * kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, hidden, bias=False)

        scaling = args.rope_scaling
        kind = scaling["type"] if scaling else None
        if kind == "su":
            rope = Phi3SuScaledRotaryEmbedding(
                self.head_dim,
                traditional=False,
                base=args.rope_theta,
                scale=1.0,
                max_position_embeddings=args.max_position_embeddings,
                original_max_position_embeddings=args.original_max_position_embeddings,
                short_factor=scaling["short_factor"],
                long_factor=scaling["long_factor"],
            )
        else:
            rope = nn.RoPE(
                self.head_dim,
                traditional=args.rope_traditional,
                base=args.rope_theta,
                scale=1 / scaling["factor"] if kind == "linear" else 1.0,
            )
        self.rope = rope

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        """Attend over ``x`` and any cached keys/values.

        Args:
            x: Three-dimensional input; the first two axes are treated as
                batch and sequence.
            mask: Optional mask forwarded to the attention kernel.
            cache: Optional ``(keys, values)`` pair from earlier calls. Its
                axis 2 length is used as the rotary position offset.

        Returns:
            A pair ``(output, (keys, values))`` where the second element is
            the updated cache. Note that this is a tuple even though the
            annotation says ``mx.array``.
        """
        B, L, _ = x.shape

        q_end = self.n_heads * self.head_dim
        kv_end = q_end + self.n_kv_heads * self.head_dim
        queries, keys, values = mx.split(self.qkv_proj(x), [q_end, kv_end], axis=-1)

        def split_heads(t, n):
            # (B, L, n * head) -> (B, n, L, head)
            return t.reshape(B, L, n, -1).transpose(0, 2, 1, 3)

        queries = split_heads(queries, self.n_heads)
        keys = split_heads(keys, self.n_kv_heads)
        values = split_heads(values, self.n_kv_heads)

        if cache is None:
            queries = self.rope(queries)
            keys = self.rope(keys)
        else:
            past_keys, past_values = cache[0], cache[1]
            # New tokens continue from where the cached sequence ends.
            offset = past_keys.shape[2]
            queries = self.rope(queries, offset=offset)
            keys = self.rope(keys, offset=offset)
            keys = mx.concatenate([past_keys, keys], axis=2)
            values = mx.concatenate([past_values, values], axis=2)

        attended = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        merged = attended.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(merged), (keys, values)
|
||
|
||
class MLP(nn.Module):
    """Gated feed-forward block with a fused gate/up projection."""

    def __init__(self, dim, hidden_dim):
        """Create the fused ``dim -> 2 * hidden_dim`` and ``hidden_dim -> dim`` layers."""
        super().__init__()
        fused_width = 2 * hidden_dim
        self.gate_up_proj = nn.Linear(dim, fused_width, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def __call__(self, x) -> mx.array:
        """Return ``down_proj(silu(gate) * up)`` for the two halves of the fused projection."""
        # The fused output is cut into two equal halves along the last axis:
        # the first half is the gate, the second the value being gated.
        gate, up = mx.split(self.gate_up_proj(x), 2, axis=-1)
        activated = nn.silu(gate) * up
        return self.down_proj(activated)
|
||
|
||
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: attention and MLP, each with a residual add."""

    def __init__(self, args: TextConfig):
        """Build the attention, MLP and the two RMSNorm layers from ``args``."""
        super().__init__()
        width = args.hidden_size
        eps = args.rms_norm_eps
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = width
        self.self_attn = Attention(args)
        self.mlp = MLP(width, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(width, eps=eps)
        self.post_attention_layernorm = nn.RMSNorm(width, eps=eps)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        """Run the block on ``x``.

        Returns:
            A pair ``(output, cache)`` where ``cache`` is whatever the
            attention layer returned as its updated key/value cache. Note
            that this is a tuple even though the annotation says ``mx.array``.
        """
        attn_out, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        hidden = x + attn_out
        mlp_out = self.mlp(self.post_attention_layernorm(hidden))
        return hidden + mlp_out, cache
|
||
|
||
class Phi3V(nn.Module):
    """Token embedding, vision embedding, decoder stack and final norm."""

    def __init__(self, args: TextConfig):
        """Build the embeddings, ``args.num_hidden_layers`` blocks and the final norm."""
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.vision_embed_tokens = VisionModel(args)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        pixel_values=None,
        image_sizes=None,
        cache=None,
    ):
        """Embed ``inputs``, optionally merge image features, run the decoder.

        Args:
            inputs: Token ids.
            pixel_values: Optional image input; when given, the vision
                module is called to update the token embeddings.
            image_sizes: Forwarded unchanged to the vision module.
            cache: Optional list with one entry per layer. It is updated in
                place and also returned.

        Returns:
            A pair ``(hidden_states, cache)``; ``hidden_states`` has the
            final RMSNorm applied.
        """
        h = self.embed_tokens(inputs)
        if pixel_values is not None:
            # Positions of negative token ids are only consumed by the vision
            # module, so compute them only on this path: the NumPy round-trip
            # would otherwise run on every text-only decoding step.
            # NOTE(review): presumably negative ids mark image placeholders;
            # confirm against VisionModel.
            p = np.argwhere(inputs < 0).tolist()
            h = self.vision_embed_tokens(pixel_values, h, image_sizes, p)

        # A single-token step needs no causal mask.
        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)
        for i, layer in enumerate(self.layers):
            h, cache[i] = layer(h, mask, cache[i])
        return self.norm(h), cache
|
||
|
||
class Model(nn.Module):
    """``Phi3V`` backbone plus a linear language-modelling head."""

    def __init__(self, args: TextConfig):
        """Build the backbone and the ``hidden_size -> vocab_size`` head."""
        super().__init__()
        self.model_type = args.model_type
        self.model = Phi3V(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
        self.config = args

    def __call__(
        self,
        inputs: mx.array,
        pixel_values=None,
        mask=None,
        cache=None,
    ):
        """Run the backbone and project its output to vocabulary logits.

        Args:
            inputs: Token ids.
            pixel_values: Optional image input forwarded to the backbone.
            mask: Forwarded positionally to the backbone (see note below).
            cache: Optional per-layer cache forwarded to the backbone.

        Returns:
            A pair ``(logits, cache)``; the logits are cast to the dtype of
            the head's weight.
        """
        # NOTE(review): ``mask`` is passed as the third positional argument,
        # which Phi3V names ``image_sizes``. Callers therefore appear to
        # supply image sizes through this parameter; confirm before renaming.
        out, cache = self.model(inputs, pixel_values, mask, cache)
        return self.lm_head(out).astype(self.lm_head.weight.dtype), cache

    @property
    def layers(self):
        """The backbone's list of transformer blocks."""
        return self.model.layers

    @property
    def head_dim(self):
        """Per-head width: ``hidden_size // num_attention_heads``."""
        # The configuration is stored as ``self.config``; there is no
        # ``self.args`` on this class.
        return self.config.hidden_size // self.config.num_attention_heads

    @property
    def n_kv_heads(self):
        """Number of key/value heads taken from the configuration."""
        return self.config.num_key_value_heads

    @property
    def language_model(self):
        """This object itself; the language model is not a separate module."""
        return self

    @property
    def vision_model(self):
        """The backbone's vision embedding module."""
        return self.model.vision_embed_tokens
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import math | ||
|
||
import mlx.core as mx | ||
|
||
|
||
class Phi3SuScaledRotaryEmbedding:
    def __init__(
        self,
        dims: int,
        traditional: bool = False,
        base: float = 10000.0,
        scale: float = 1.0,
        max_position_embeddings: int = 131072,
        original_max_position_embeddings: int = 4096,
        short_factor: list[float] | float = 1.0,
        long_factor: list[float] | float = 1.0,
    ):
        """
        Phi3Su Scaled Rotary Embedding layer for Phi-3 models.
        Args:
            dims (int): The feature dimensions to be rotated.
            traditional (bool, optional): Unused. Default: ``False``.
            base (float, optional): Base for the exponential scaling. Default: 10000.0.
            scale (float, optional): Extra multiplier applied to the long-context frequencies only. Default: 1.0.
            max_position_embeddings (int, optional): The extended maximum sequence length. Together with original_max_position_embeddings it determines the factor by which cos/sin are scaled. Default: 131072.
            original_max_position_embeddings (int, optional): The maximum sequence length before extension. Sequences reaching past it use long_factor, others use short_factor. Default: 4096.
            short_factor (float or list of floats, optional): Scaling factors used while the sequence ends at or before original_max_position_embeddings. Default: 1.0.
            long_factor (float or list of floats, optional): Scaling factors used once the sequence extends past original_max_position_embeddings. Default: 1.0.
        """
        self.inv_freq_short = 1.0 / (
            mx.array(short_factor, dtype=mx.float32)
            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
        )
        self.inv_freq_long = 1.0 / (
            scale
            * mx.array(long_factor, dtype=mx.float32)
            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
        )
        self.original_max_position_embeddings = original_max_position_embeddings
        # sqrt(1 + ln(max / original) / ln(original)); multiplies cos and sin.
        self.scaling_factor = math.sqrt(
            1
            + math.log(max_position_embeddings / original_max_position_embeddings)
            / math.log(original_max_position_embeddings)
        )

    def _get_cos_sin(self, offset, L):
        """Return the scaled cos/sin tables for positions ``offset .. offset + L - 1``.

        Both results have a singleton axis inserted at position 1 so they
        broadcast against the array passed to ``__call__``.
        """
        position_ids = mx.arange(offset, offset + L, dtype=mx.float32)[None]
        # The largest position is offset + L - 1, so "max position + 1" is
        # simply offset + L. Comparing plain integers avoids reducing the
        # array and converting the result to a Python bool on every call.
        inv_freq = (
            self.inv_freq_long
            if offset + L > self.original_max_position_embeddings
            else self.inv_freq_short
        )
        inv_freq_expanded = mx.repeat(
            inv_freq[None, :, None], position_ids.shape[0], axis=0
        )
        position_ids_expanded = position_ids[:, None, :]
        # Outer product of frequencies and positions, then positions first.
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1)
        emb = mx.concatenate([freqs, freqs], axis=-1)
        cos = mx.cos(emb) * self.scaling_factor
        sin = mx.sin(emb) * self.scaling_factor
        return mx.expand_dims(cos, axis=1), mx.expand_dims(sin, axis=1)

    def __call__(self, x, offset: int = 0):
        """Rotate ``x``; axis 2 of ``x`` is the sequence axis, starting at ``offset``."""

        def _rotate_half(_x):
            # Swap the two halves of the last axis, negating the second one.
            midpoint = _x.shape[-1] // 2
            x1, x2 = _x[..., :midpoint], _x[..., midpoint:]
            return mx.concatenate([-x2, x1], axis=-1)

        cos, sin = self._get_cos_sin(offset, x.shape[2])
        return (x * cos) + (_rotate_half(x) * sin)
Oops, something went wrong.