forked from haotian-liu/LLaVA
Commit
Merge pull request haotian-liu#595 from yorickvP/main
Add Replicate demo and API
Showing 4 changed files with 216 additions and 1 deletion.
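This commit packages LLaVA for Replicate. Once deployed there, the model can be called from the HTTP API or a client library; a minimal sketch with the Python client, assuming REPLICATE_API_TOKEN is set and that the deployment lives at the slug yorickvp/llava-13b (an assumption based on this PR's demo; substitute your own):

    import replicate

    # model slug and image path are placeholders; adjust to your deployment
    output = replicate.run(
        "yorickvp/llava-13b",
        input={
            "image": open("view.jpg", "rb"),
            "prompt": "What is unusual about this image?",
        },
    )
    for token in output:  # the predictor streams tokens
        print(token, end="")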
.dockerignore
@@ -0,0 +1,21 @@
# The .dockerignore file excludes files from the container build process.
#
# https://docs.docker.com/engine/reference/builder/#dockerignore-file

# Exclude Git files
.git
.github
.gitignore

# Exclude Python cache files
__pycache__
.mypy_cache
.pytest_cache
.ruff_cache

# Exclude Python virtual environment
/venv

# Exclude some weights
/openai
/liuhaotian
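The weight directories in the final stanza are excluded because predict.py (below) downloads those weights when the container first starts; keeping local copies out of the build context keeps the image small. The image itself builds as usual with Cog; the tag name here is arbitrary:

    cog build -t llava-13b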
cog.yaml
@@ -0,0 +1,37 @@
# Configuration for Cog ⚙️
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md

build:
  gpu: true

  python_version: "3.11"

  python_packages:
    - "torch==2.0.1"
    - "accelerate==0.21.0"
    - "bitsandbytes==0.41.0"
    - "deepspeed==0.9.5"
    - "einops-exts==0.0.4"
    - "einops==0.6.1"
    - "gradio==3.35.2"
    - "gradio_client==0.2.9"
    - "httpx==0.24.0"
    - "markdown2==2.4.10"
    - "numpy==1.26.0"
    - "peft==0.4.0"
    - "scikit-learn==1.2.2"
    - "sentencepiece==0.1.99"
    - "shortuuid==1.0.11"
    - "timm==0.6.13"
    - "tokenizers==0.13.3"
    - "torchvision==0.15.2"
    - "transformers==4.31.0"
    - "wandb==0.15.12"
    - "wavedrom==2.0.3.post3"
    - "Pygments==2.16.1"

  run:
    # install pget, used by predict.py to fetch large weight files quickly
    - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
predict.py
@@ -0,0 +1,157 @@
import os
import subprocess
import time
from io import BytesIO
from threading import Thread

import requests
import torch
from PIL import Image

from cog import BasePredictor, ConcatenateIterator, Input, Path
from transformers.generation.streamers import TextIteratorStreamer

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria

# keep Hugging Face hub downloads inside the working directory
os.environ["HUGGINGFACE_HUB_CACHE"] = os.getcwd() + "/weights"

# url for the weights mirror
REPLICATE_WEIGHTS_URL = "https://weights.replicate.delivery/default"
# files to download from the weights mirror
weights = [
    {
        "dest": "liuhaotian/llava-v1.5-13b",
        # git commit hash from huggingface
        "src": "llava-v1.5-13b/006818fc465ebda4c003c0998674d9141d8d95f8",
        "files": [
            "config.json",
            "generation_config.json",
            "pytorch_model-00001-of-00003.bin",
            "pytorch_model-00002-of-00003.bin",
            "pytorch_model-00003-of-00003.bin",
            "pytorch_model.bin.index.json",
            "special_tokens_map.json",
            "tokenizer.model",
            "tokenizer_config.json",
        ],
    },
    {
        "dest": "openai/clip-vit-large-patch14-336",
        "src": "clip-vit-large-patch14-336/ce19dc912ca5cd21c8a653c79e251e808ccabcd1",
        "files": [
            "config.json",
            "preprocessor_config.json",
            "pytorch_model.bin",
        ],
    },
]

def download_json(url: str, dest: Path):
    res = requests.get(url, allow_redirects=True)
    if res.status_code == 200 and res.content:
        with dest.open("wb") as f:
            f.write(res.content)
    else:
        print(f"Failed to download {url}. Status code: {res.status_code}")


def download_weights(baseurl: str, basedest: str, files: list[str]):
    basedest = Path(basedest)
    start = time.time()
    print("downloading to: ", basedest)
    basedest.mkdir(parents=True, exist_ok=True)
    for f in files:
        dest = basedest / f
        url = os.path.join(REPLICATE_WEIGHTS_URL, baseurl, f)
        if not dest.exists():
            print("downloading url: ", url)
            if dest.suffix == ".json":
                # JSON configs are small; fetch them directly over HTTPS
                download_json(url, dest)
            else:
                # large binary shards go through pget (installed in cog.yaml)
                subprocess.check_call(["pget", url, str(dest)], close_fds=False)
    print("downloading took: ", time.time() - start)

class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        for weight in weights:
            download_weights(weight["src"], weight["dest"], weight["files"])
        disable_torch_init()

        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            "liuhaotian/llava-v1.5-13b",
            model_name="llava-v1.5-13b",
            model_base=None,
            load_8bit=False,
            load_4bit=False,
        )

    def predict(
        self,
        image: Path = Input(description="Input image"),
        prompt: str = Input(description="Prompt to use for text generation"),
        top_p: float = Input(description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens", ge=0.0, le=1.0, default=1.0),
        temperature: float = Input(description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic", default=0.2, ge=0.0),
        max_tokens: int = Input(description="Maximum number of tokens to generate. A word is generally 2-3 tokens", default=1024, ge=0),
    ) -> ConcatenateIterator[str]:
        """Run a single prediction on the model"""
        conv_mode = "llava_v1"
        conv = conv_templates[conv_mode].copy()

        image_data = load_image(str(image))
        image_tensor = self.image_processor.preprocess(image_data, return_tensors='pt')['pixel_values'].half().cuda()

        # single turn: always prepend the image token to the user prompt
        inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=20.0)

        with torch.inference_mode():
            # note: do_sample is always True, so a temperature of exactly 0 is
            # rejected by transformers despite the parameter description above
            thread = Thread(target=self.model.generate, kwargs=dict(
                inputs=input_ids,
                images=image_tensor,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                max_new_tokens=max_tokens,
                streamer=streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria]))
            thread.start()
            # workaround: the chunk just before the stop string is always a lone " ".
            # Hold any lone-space chunk back, drop it if the stop string follows,
            # and otherwise re-attach it to the next chunk (or emit it at stream end).
            prepend_space = False
            for new_text in streamer:
                if new_text == " ":
                    prepend_space = True
                    continue
                if new_text.endswith(stop_str):
                    new_text = new_text[:-len(stop_str)].strip()
                    prepend_space = False
                elif prepend_space:
                    new_text = " " + new_text
                    prepend_space = False
                if len(new_text):
                    yield new_text
            if prepend_space:
                yield " "
            thread.join()

def load_image(image_file):
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image
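For a quick smoke test outside the Cog runtime, the predictor can also be driven directly from Python; a minimal sketch, assuming a CUDA GPU, the llava package on PYTHONPATH, and a placeholder test.jpg:

    from cog import Path
    from predict import Predictor

    predictor = Predictor()
    predictor.setup()  # fetches the weights on first run
    for chunk in predictor.predict(
        image=Path("test.jpg"),
        prompt="Describe this image.",
        top_p=1.0,
        temperature=0.2,
        max_tokens=256,
    ):
        print(chunk, end="", flush=True)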