Possible to use with a VL model like LLAVA? #16
Based on my script here, it should be fairly out-of-the-box to compile and run, and I do get about a 4x speed-up:

```python
import os
from contextlib import contextmanager
from functools import partial
from time import perf_counter
from typing import Optional


@contextmanager
def catchtime(s) -> float:
    start = perf_counter()
    yield lambda: perf_counter() - start
    print(f'Time of {s=}: {perf_counter() - start:.3f} seconds')


import requests
import torch
from PIL import Image
from tqdm import tqdm
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor, StaticCache

# noinspection PyProtectedMember
torch._inductor.config.coordinate_descent_tuning = True
# noinspection PyProtectedMember
torch._inductor.config.triton.unique_kernel_names = True
# noinspection PyProtectedMember
torch._inductor.config.fx_graph_cache = True

os.environ["TOKENIZERS_PARALLELISM"] = "true"
MODEL_NAME = 'models/llava-v1.6-mistral-7b-hf'


def mem():
    return torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024, torch.cuda.memory_allocated() / 1024 / 1024 / 1024


assert torch.cuda.is_available()
device = "cuda"


def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
    """Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L40C1-L42C82"""
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)


def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L44C1-L52C17"""
    logits = logits / max(temperature, 1e-5)
    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs


def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L54C1-L57C27"""
    probs = logits_to_probs(logits[:, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs


def decode_one_tokens(model, cur_token, cache_position):
    """Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L64C1-L68C45"""
    logits = model(cur_token, cache_position=cache_position, return_dict=False, use_cache=True)[0]
    new_token = sample(logits, temperature=0)[0]
    return new_token


decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True)


def gen(model, inputs, iters=100):
    print(inputs['input_ids'].shape, inputs['image_sizes'])
    generated_ids = torch.zeros((1, iters), dtype=torch.int, device=device)
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=True, enable_math=False):
        output = model(**inputs)
        seq_len, logits = output.loss, output.logits
        cache_position = torch.tensor([seq_len], device=device)
        input_id = sample(logits, temperature=0)[0]
        generated_ids[:, 0] = input_id[:, 0]
        gen_pos = torch.tensor([1], device=device)
    print('post-1st ', mem())
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
        for i in tqdm(range(iters - 1)):
            input_id = decode_one_tokens(model.language_model, input_id.clone(), cache_position)
            generated_ids.index_copy_(1, gen_pos, input_id)
            cache_position += 1
            gen_pos += 1
    print('post-last ', mem())
    return generated_ids


with torch.inference_mode():
    processor = LlavaNextProcessor.from_pretrained(MODEL_NAME)
    url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
    image = Image.open(requests.get(url, stream=True).raw)
    prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
    torch.cuda.memory._record_memory_history()
    print('pre-model', mem())
    model = LlavaNextForConditionalGeneration.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(device)
    print('pre-cache', mem())
    static_cache = partial(StaticCache, dtype=torch.float16)
    model.language_model._setup_cache(static_cache, max_batch_size=1, max_cache_len=4096)
    print('pre-comp ', mem())
    model.language_model.compile(mode='reduce-overhead', fullgraph=True)
    model.vision_tower.compile(mode='reduce-overhead', fullgraph=True)
    print('pre-proc ', mem())
    inputs = processor(prompt, image, return_tensors="pt").to(device)
    print('pre-gen1 ', mem())
    with catchtime('first compile gen:'):
        out = gen(model, inputs, iters=10)
    # print(out)
    print(processor.decode(out[0], skip_special_tokens=True))
    print('pre-gen2 ', mem())
    with catchtime('second compile gen:'):
        out = gen(model, inputs, iters=100)
    # print(out)
    print(processor.decode(out[0], skip_special_tokens=True))
    # torch.cuda.memory._dump_snapshot("snapshot_full.pickle")
```

But the issue here is that compiling the full fp16 model requires more than 16 GB of VRAM, which is more than what I have available for production.
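For a rough sense of where that memory goes, here is a back-of-the-envelope estimate. It assumes the usual Mistral-7B configuration (about 7.24B language-model parameters, 32 layers, 8 KV heads of dimension 128) and a ~0.3B-parameter vision tower; the exact figures depend on the checkpoint, and compile/CUDA-graph workspaces are not counted:

```python
# Back-of-the-envelope fp16 memory estimate (assumed Mistral-7B-style config;
# illustrative numbers, not measurements).
GiB = 1024 ** 3

lm_params = 7.24e9       # assumed language-model parameter count
vision_params = 0.3e9    # assumed vision tower (CLIP ViT-L/14-336) parameter count
bytes_per_param = 2      # fp16

weights_gib = (lm_params + vision_params) * bytes_per_param / GiB

# Static KV cache: layers * 2 (K and V) * max_cache_len * kv_heads * head_dim * 2 bytes
kv_cache_gib = 32 * 2 * 4096 * 8 * 128 * 2 / GiB

print(f"weights  ~ {weights_gib:.1f} GiB")   # ~14.0 GiB
print(f"kv cache ~ {kv_cache_gib:.1f} GiB")  # ~0.5 GiB
```

Activations for the long multimodal prefill plus the inductor/CUDA-graph workspaces come on top of that, which is how the total ends up past 16 GB; int8 weights would roughly halve the dominant term.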
Hey, apologies for the late response! I will look into this and get back to you soon :)
I am trying to use this project with a vision-language model like https://huggingface.co/docs/transformers/en/model_doc/llava_next, but this repo does not currently support the vision part of the model. I have a separate script that works by splitting out the vision tower and compiling the two parts separately. Do you think it would be possible to do the same using your project? My separate script is not fully using gptfast yet, especially the int8 part, so I would really like to use your awesome work here.
I am using https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf specifically.
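For concreteness, this is roughly the kind of split I have in mind: a minimal sketch of gpt-fast-style int8 weight-only quantization applied only to the language model, with the vision tower left in fp16. The `language_model` / `vision_tower` attribute names follow transformers' LlavaNext layout; nothing below uses this repo's API, it is just an illustration:

```python
# Illustration only: gpt-fast-style int8 weight-only quantization of the language
# model, leaving the vision tower in fp16. Not this repo's API.
import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightOnlyInt8Linear(nn.Module):
    """nn.Linear replacement storing int8 weights plus per-output-row scales."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        w = linear.weight.data
        scales = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        self.register_buffer("weight", torch.round(w / scales).to(torch.int8))
        self.register_buffer("scales", scales.squeeze(1))
        self.bias = linear.bias

    def forward(self, x):
        # Dequantize on the fly: cast int8 weights to the activation dtype, rescale per row.
        out = F.linear(x, self.weight.to(dtype=x.dtype)) * self.scales
        return out if self.bias is None else out + self.bias


def quantize_linears_(module: nn.Module) -> None:
    """Recursively swap every nn.Linear under `module` for the int8 version."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, WeightOnlyInt8Linear(child))
        else:
            quantize_linears_(child)


# Hypothetical usage: quantize only the language model, keep the vision tower fp16.
# model = LlavaNextForConditionalGeneration.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
# quantize_linears_(model.language_model)
# model.language_model.compile(mode="reduce-overhead", fullgraph=True)
```

Weight-only int8 keeps the matmuls in the activation dtype (weights are dequantized on the fly), so it should compose with torch.compile the same way the fp16 path does while roughly halving the weight memory.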