Skip to content

Commit

Permalink
Refactor KVCache and add stream generate (#53)
Browse files Browse the repository at this point in the history
* refactor KVCache and add stream generate

* update KV Cache

* fix tests and update generate

* temporary fix to deepseek and paligemma

* refactor chat_template and fix idefics KV cache

* remove tuple and update model
  • Loading branch information
Blaizzy authored Jul 26, 2024
1 parent 9421149 commit 6c98971
Show file tree
Hide file tree
Showing 20 changed files with 548 additions and 452 deletions.
55 changes: 16 additions & 39 deletions mlx_vlm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@

import mlx.core as mx

from .prompt_utils import get_message_json
from .prompt_utils import apply_chat_template
from .utils import generate, get_model_path, load, load_config, load_image_processor

MODEL_TYPE = ""
DEFAULT_MODEL_PATH = "mlx-community/nanoLLaVA-1.5-8bit"
DEFAULT_IMAGE = "http://images.cocodataset.org/val2017/000000039769.jpg"
DEFAULT_PROMPT = "What are these?"
DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.5
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0


def parse_arguments():
Expand All @@ -16,36 +22,31 @@ def parse_arguments():
parser.add_argument(
"--model",
type=str,
default="qnguyen3/nanoLLaVA",
default=DEFAULT_MODEL_PATH,
help="The path to the local model directory or Hugging Face repo.",
)
parser.add_argument(
"--image",
type=str,
default="http://images.cocodataset.org/val2017/000000039769.jpg",
default=DEFAULT_IMAGE,
help="URL or path of the image to process.",
)
parser.add_argument(
"--prompt",
type=str,
default="What are these?",
default=DEFAULT_PROMPT,
help="Message to be processed by the model.",
)
parser.add_argument(
"--max-tokens",
type=int,
default=100,
default=DEFAULT_MAX_TOKENS,
help="Maximum number of tokens to generate.",
)
parser.add_argument(
"--temp", type=float, default=0.3, help="Temperature for sampling."
)
parser.add_argument(
"--verbose",
type=bool,
help="Detailed output.",
default=True,
"--temp", type=float, default=DEFAULT_TEMP, help="Temperature for sampling."
)
parser.add_argument("--verbose", action="store_false", help="Detailed output.")
return parser.parse_args()


Expand All @@ -57,38 +58,14 @@ def get_model_and_processors(model_path):
return model, processor, image_processor, config


def sample(logits, temperature=0.0):
    """Pick the next token id from *logits*.

    A temperature of exactly 0 is greedy decoding (arg-max over the last
    axis); any other value scales the logits by ``1 / temperature`` and
    draws from the resulting categorical distribution.
    """
    if temperature != 0:
        scaled_logits = logits * (1 / temperature)
        return mx.random.categorical(scaled_logits)
    return mx.argmax(logits, axis=-1)


def main():
args = parse_arguments()
model, processor, image_processor, config = get_model_and_processors(args.model)

prompt = codecs.decode(args.prompt, "unicode_escape")

if "chat_template" in processor.__dict__.keys():
prompt = processor.apply_chat_template(
[get_message_json(config["model_type"], prompt)],
tokenize=False,
add_generation_prompt=True,
)

elif "tokenizer" in processor.__dict__.keys():
if model.config.model_type != "paligemma":
prompt = processor.tokenizer.apply_chat_template(
[get_message_json(config["model_type"], prompt)],
tokenize=False,
add_generation_prompt=True,
)

else:
ValueError(
"Error: processor does not have 'chat_template' or 'tokenizer' attribute."
)
if model.config.model_type != "paligemma":
prompt = apply_chat_template(processor, config, prompt)

output = generate(
model,
Expand Down
39 changes: 39 additions & 0 deletions mlx_vlm/models/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import Dict

import mlx.core as mx
from PIL import Image
from transformers.image_processing_utils import get_size_dict
from transformers.image_utils import ChannelDimension, PILImageResampling
Expand Down Expand Up @@ -49,3 +50,41 @@ def __init__(
@abstractmethod
def preprocess(self, images):
pass


class KVCache:
    """Per-layer key/value cache for autoregressive attention.

    Buffers are pre-allocated and grow along the sequence axis in fixed
    increments of ``step`` tokens, so token-by-token decoding does not
    reallocate on every call. ``update_and_fetch`` appends new keys/values
    in place and returns views over the filled portion only.
    """

    def __init__(self, head_dim, n_kv_heads):
        # head_dim: an int applies to both keys and values; a 2-tuple
        # (k_dim, v_dim) sizes them independently — presumably for models
        # with asymmetric key/value head dims (e.g. the deepseek fix
        # mentioned in this commit) — verify against callers.
        self.n_kv_heads = n_kv_heads
        if isinstance(head_dim, int):
            self.k_head_dim = self.v_head_dim = head_dim
        elif isinstance(head_dim, tuple) and len(head_dim) == 2:
            self.k_head_dim, self.v_head_dim = head_dim
        else:
            raise ValueError("head_dim must be an int or a tuple of two ints")
        self.keys = None    # (1, n_kv_heads, capacity, k_head_dim) once allocated
        self.values = None  # (1, n_kv_heads, capacity, v_head_dim) once allocated
        self.offset = 0     # number of sequence positions currently filled
        self.step = 256     # allocation granularity along the sequence axis

    def update_and_fetch(self, keys, values):
        """Append *keys*/*values* (sequence length read from ``shape[2]``)
        and return the cached keys/values up to the new offset.

        NOTE(review): the allocated shapes hard-code a batch size of 1.
        """
        prev = self.offset
        # Grow the buffers when the incoming tokens would overflow capacity.
        if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
            # Whole number of `step`-sized chunks needed for the new tokens
            # (ceil division).
            n_steps = (self.step + keys.shape[2] - 1) // self.step
            k_shape = (1, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
            v_shape = (1, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
            new_k = mx.zeros(k_shape, keys.dtype)
            new_v = mx.zeros(v_shape, values.dtype)
            if self.keys is not None:
                if prev % self.step != 0:
                    # Drop the unused tail of the last chunk so the filled
                    # region stays contiguous after concatenation.
                    self.keys = self.keys[..., :prev, :]
                    self.values = self.values[..., :prev, :]
                self.keys = mx.concatenate([self.keys, new_k], axis=2)
                self.values = mx.concatenate([self.values, new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v

        self.offset += keys.shape[2]
        # Write the new tokens in place, then return only the filled slice.
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
4 changes: 2 additions & 2 deletions mlx_vlm/models/idefics2/idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,10 @@ def __call__(
self, input_ids: mx.array, pixel_values: mx.array, mask: mx.array, cache=None
):
input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
logits, cache = self.language_model(
logits = self.language_model(
inputs=input_ids, cache=cache, inputs_embeds=input_embeddings
)
return logits, cache
return logits

@staticmethod
def from_pretrained(path_or_hf_repo: str):
Expand Down
76 changes: 41 additions & 35 deletions mlx_vlm/models/idefics2/language.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ def __post_init__(self):


class Attention(nn.Module):
def __init__(self, args: TextConfig):
def __init__(self, config: TextConfig):
super().__init__()

dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
dim = config.hidden_size
self.n_heads = n_heads = config.num_attention_heads
self.n_kv_heads = n_kv_heads = config.num_key_value_heads

head_dim = args.hidden_size // n_heads
head_dim = config.hidden_size // n_heads
self.scale = head_dim**-0.5

self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False)
Expand All @@ -54,8 +54,8 @@ def __init__(self, args: TextConfig):

self.rope = nn.RoPE(
head_dim,
traditional=args.rope_traditional,
base=args.rope_theta,
traditional=config.rope_traditional,
base=config.rope_theta,
)

def __call__(
Expand All @@ -74,11 +74,9 @@ def __call__(
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
queries = self.rope(queries, offset=cache.offset)
keys = self.rope(keys, offset=cache.offset)
keys, values = cache.update_and_fetch(keys, values)
else:
queries = self.rope(queries)
keys = self.rope(keys)
Expand All @@ -87,7 +85,7 @@ def __call__(
queries, keys, values, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (keys, values)
return self.o_proj(output)


class MLP(nn.Module):
Expand All @@ -102,45 +100,45 @@ def __call__(self, x) -> mx.array:


class TransformerBlock(nn.Module):
def __init__(self, args: TextConfig):
def __init__(self, config: TextConfig):
super().__init__()
self.num_attention_heads = args.num_attention_heads
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.self_attn = Attention(config)
self.mlp = MLP(config.hidden_size, config.intermediate_size)
self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
config.hidden_size, eps=config.rms_norm_eps
)
self.args = args
self.config = config

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, cache
return out


class LanguageModel(nn.Module):
def __init__(self, args: TextConfig):
def __init__(self, config: TextConfig):
super().__init__()
self.args = args
self.model_type = args.model_type
self.vocab_size = args.vocab_size
self.num_hidden_layers = args.num_hidden_layers
self.config = config
self.model_type = config.model_type
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers
assert self.vocab_size > 0
self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
]
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

def __call__(
self,
Expand All @@ -163,10 +161,10 @@ def __call__(
if cache is None:
cache = [None] * len(self.layers)

for e, layer in enumerate(self.layers):
h, cache[e] = layer(h, mask, cache[e])
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)

return self.lm_head(self.norm(h)), cache
return self.lm_head(self.norm(h))

def sanitize(self, weights):
# Remove unused precomputed rotary freqs
Expand All @@ -177,3 +175,11 @@ def sanitize(self, weights):
    @property
    def layers(self):
        # Delegate to the inner model's transformer layers — presumably so
        # generation code can build one KVCache per layer; verify against
        # the caller in utils/generate.
        return self.model.layers

    @property
    def head_dim(self):
        # Per-attention-head feature size: hidden size split evenly across
        # the query heads.
        return self.config.hidden_size // self.config.num_attention_heads

    @property
    def n_kv_heads(self):
        # Number of key/value heads; may differ from the query-head count
        # when the model uses grouped-query attention — not verifiable here.
        return self.config.num_key_value_heads
61 changes: 0 additions & 61 deletions mlx_vlm/models/llava/README.md

This file was deleted.

Loading

0 comments on commit 6c98971

Please sign in to comment.