Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Possible to use with a VL model like LLAVA? #16

Open
aliencaocao opened this issue Mar 30, 2024 · 2 comments
Open

Possible to use with a VL model like LLAVA? #16

aliencaocao opened this issue Mar 30, 2024 · 2 comments

Comments

@aliencaocao
Copy link

I am trying to use this project with a vision-language model like https://huggingface.co/docs/transformers/en/model_doc/llava_next, but this repo currently does not support the vision part of the model. I have a separate script that works by splitting off the vision tower and compiling the two parts separately. Do you think it would be possible to do the same with your project? My script does not yet make full use of gpt-fast, especially the int8 part, so I would really like to use your excellent work here.

I am using https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf specifically.

@aliencaocao
Copy link
Author

aliencaocao commented Mar 30, 2024

Based on my script below, compiling and running the model should work more or less out of the box, and I get about a 4x speedup:

import os
from contextlib import contextmanager
from functools import partial
from time import perf_counter
from typing import Callable, Iterator, Optional


@contextmanager
def catchtime(s) -> Iterator[Callable[[], float]]:
    """Time the enclosed ``with`` block and print the elapsed wall-clock time.

    Args:
        s: label included in the printed timing line.

    Yields:
        A zero-argument callable returning the seconds elapsed since the block
        was entered (it keeps counting after the block exits).

    The timing line is printed when the block exits, also when it exits via an
    exception; the exception then propagates unchanged.

    NOTE: perf_counter measures host time only; asynchronous GPU work is
    included only if something inside the block synchronizes with the device.
    """
    start = perf_counter()
    try:
        yield lambda: perf_counter() - start
    finally:
        print(f'Time of {s=}: {perf_counter() - start:.3f} seconds')


import requests
import torch
from PIL import Image
from tqdm import tqdm

from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor, StaticCache

# noinspection PyProtectedMember
torch._inductor.config.coordinate_descent_tuning = True
# noinspection PyProtectedMember
torch._inductor.config.triton.unique_kernel_names = True
# noinspection PyProtectedMember
torch._inductor.config.fx_graph_cache = True

os.environ["TOKENIZERS_PARALLELISM"] = "true"

MODEL_NAME = 'models/llava-v1.6-mistral-7b-hf'


def mem():
    """Return ``(peak, current)`` CUDA memory allocated by tensors, in GiB.

    ``peak`` comes from ``torch.cuda.max_memory_allocated()`` and ``current``
    from ``torch.cuda.memory_allocated()``, both for the current device.
    """
    bytes_per_gib = 1024 ** 3
    peak = torch.cuda.max_memory_allocated()
    current = torch.cuda.memory_allocated()
    return peak / bytes_per_gib, current / bytes_per_gib


# The rest of the script moves the model and inputs to "cuda" and reads
# torch.cuda memory statistics, so fail fast without a GPU. An explicit raise
# is used because `assert` statements are stripped under `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError("This script requires a CUDA-capable GPU (torch.cuda.is_available() is False).")
device = "cuda"



def multinomial_sample_one_no_sync(probs_sort):  # Does multinomial sampling without a cuda synchronization
    """Draw one index per row from the categorical distribution ``probs_sort``.

    Uses the exponential-race trick: ``argmax(p / E)`` with ``E ~ Exp(1)``
    i.i.d. selects index ``i`` with probability ``p_i``. Unlike
    ``torch.multinomial`` it needs no host/device sync. Returns a
    ``keepdim=True`` tensor of dtype ``torch.int``.

    Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L40C1-L42C82
    """
    noise = torch.empty_like(probs_sort)
    noise.exponential_(1)
    race = probs_sort / noise
    winner = race.argmax(dim=-1, keepdim=True)
    return winner.to(dtype=torch.int)


def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Turn logits into a probability distribution over the last dimension.

    Args:
        logits: raw scores; softmax is taken over the last dimension.
        temperature: divides the logits; clamped below at 1e-5, so 0 behaves
            like (near-)greedy decoding.
        top_k: if given, entries strictly below the k-th largest logit
            (k capped at the vocabulary size) are masked to -inf first.

    Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L44C1-L52C17
    """
    scaled = logits / max(temperature, 1e-5)
    if top_k is None:
        return torch.nn.functional.softmax(scaled, dim=-1)

    top_values = torch.topk(scaled, min(top_k, scaled.size(-1))).values
    # Smallest of the kept top-k values, broadcastable against `scaled`.
    threshold = top_values.select(-1, -1).unsqueeze(-1)
    kept = torch.where(scaled < threshold, -float("Inf"), scaled)
    return torch.nn.functional.softmax(kept, dim=-1)


def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    """Sample the next token from the logits of the last sequence position.

    Only ``logits[:, -1]`` is used. Returns ``(idx_next, probs)``: the sampled
    indices (see ``multinomial_sample_one_no_sync``) and the distribution they
    were drawn from.

    Copied from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L54C1-L57C27
    """
    last_position_logits = logits[:, -1]
    probs = logits_to_probs(last_position_logits, temperature, top_k)
    return multinomial_sample_one_no_sync(probs), probs


@torch.compile(mode="reduce-overhead", fullgraph=True)
def decode_one_tokens(model, cur_token, cache_position):
    """Run one decode step and return the next token id(s).

    Feeds ``cur_token`` at ``cache_position`` through ``model`` (here the
    language model with its static KV cache), takes the logits and samples
    with temperature 0 (effectively greedy). Compiled with
    ``mode="reduce-overhead"`` and ``fullgraph=True``.

    Adapted from https://github.com/pytorch-labs/gpt-fast/blob/f6973170327003c6b1ce7edb5c015b4fa0097e6d/generate.py#L64C1-L68C45
    """
    outputs = model(cur_token, cache_position=cache_position, return_dict=False, use_cache=True)
    step_logits = outputs[0]
    next_token, _probs = sample(step_logits, temperature=0)
    return next_token


def gen(model, inputs, iters=100):
    """Generate ``iters`` tokens for a single prompt and return their ids.

    The first (prefill) step runs the full multimodal ``model`` on ``inputs``.
    Each following step feeds only the last sampled token to the compiled
    ``decode_one_tokens`` on ``model.language_model``. Sampling uses
    temperature 0, which is effectively greedy.

    Args:
        model: a ``LlavaNextForConditionalGeneration`` whose ``language_model``
            has a static KV cache set up (see the script body).
        inputs: processor output. It must contain ``input_ids`` and
            ``image_sizes`` and is passed to ``model`` as keyword arguments.
        iters: total number of tokens to produce, including the one sampled
            from the prefill logits. Must be >= 1.

    Returns:
        A ``(1, iters)`` tensor of dtype ``torch.int`` on ``device`` holding
        the generated token ids.

    Raises:
        ValueError: if ``iters`` < 1. The prefill step always writes column 0
            of the output buffer, so at least one slot is required.
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")
    print(inputs['input_ids'].shape, inputs['image_sizes'])
    # Output buffer allocated once and filled in place.
    generated_ids = torch.zeros((1, iters), dtype=torch.int, device=device)
    # Prefill: one pass of the full multimodal model over prompt + image.
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=True, enable_math=False):
        output = model(**inputs)
        # NOTE(review): `output.loss` is used as the prompt length, i.e. the
        # cache position of the next token. A stock HF forward without labels
        # returns loss=None, so this presumably relies on a locally patched
        # model; confirm. If unpatched, `logits.shape[1]` may be the intended value.
        seq_len, logits = output.loss, output.logits
        cache_position = torch.tensor([seq_len], device=device)
        input_id = sample(logits, temperature=0)[0]
        generated_ids[:, 0] = input_id[:, 0]
    # Next column to write in generated_ids, kept as a device tensor
    # (presumably so index_copy_ needs no host-side index).
    gen_pos = torch.tensor([1], device=device)
    print('post-1st  ', mem())
    # Decode: one token per step through the compiled language model.
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
        for _ in tqdm(range(iters - 1)):
            # NOTE(review): clone() presumably keeps the compiled
            # (reduce-overhead) function from receiving its own previous
            # output buffer as input; confirm before removing.
            input_id = decode_one_tokens(model.language_model, input_id.clone(), cache_position)
            generated_ids.index_copy_(1, gen_pos, input_id)
            # Advance both positions in place so the same tensors are reused.
            cache_position += 1
            gen_pos += 1
    print('post-last ', mem())
    return generated_ids


with torch.inference_mode():
    processor = LlavaNextProcessor.from_pretrained(MODEL_NAME)
    url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
    # requests has no default timeout, so bound the download. raise_for_status()
    # reports HTTP errors here instead of as a confusing PIL decode error.
    response = requests.get(url, stream=True, timeout=30)
    response.raise_for_status()
    image = Image.open(response.raw)
    # Mistral-style [INST] prompt. The <image> token presumably marks where the
    # processor/model inserts the image features.
    prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"

    # Start recording CUDA allocation history for a memory snapshot.
    # NOTE(review): the matching _dump_snapshot call at the end is commented
    # out, so this presumably only adds overhead to the timings below.
    torch.cuda.memory._record_memory_history()
    print('pre-model', mem())
    model = LlavaNextForConditionalGeneration.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(device)
    print('pre-cache', mem())
    # Static KV cache for the language model, in the same fp16 dtype the model
    # was loaded with. _setup_cache is a private (underscore) HF API.
    static_cache = partial(StaticCache, dtype=torch.float16)
    model.language_model._setup_cache(static_cache, max_batch_size=1, max_cache_len=4096)
    print('pre-comp ', mem())
    # Compile the text and vision sub-models separately.
    model.language_model.compile(mode='reduce-overhead', fullgraph=True)
    model.vision_tower.compile(mode='reduce-overhead', fullgraph=True)
    print('pre-proc ', mem())
    inputs = processor(prompt, image, return_tensors="pt").to(device)

    print('pre-gen1 ', mem())
    # Short first run: pays the compilation cost.
    with catchtime('first compile gen:'):
        out = gen(model, inputs, iters=10)
        print(processor.decode(out[0], skip_special_tokens=True))
    print('pre-gen2 ', mem())
    # Second run presumably reuses the compiled graphs (steady-state speed).
    with catchtime('second compile gen:'):
        out = gen(model, inputs, iters=100)
        print(processor.decode(out[0], skip_special_tokens=True))
    # torch.cuda.memory._dump_snapshot("snapshot_full.pickle")

The problem is that compiling the full fp16 model requires more than 16 GB of VRAM, which is more than I have available in production.

@MDK8888
Copy link
Owner

MDK8888 commented Apr 2, 2024

Hey, apologies for the late response! I will look into this and get back to you soon :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants