add llava model support #2153
Conversation
Just something to share: should we adopt the OpenAI v1 API, which also supports vision input? It is more generic and supports multi-turn, vision-based conversation.
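For reference, a rough sketch of what such an OpenAI-style vision request looks like (the endpoint URL and model name here are placeholders, not something this PR provides):

```python
# Illustrative OpenAI v1 chat-completions payload with image input; the endpoint
# and model name are placeholders, not part of this PR.
import requests

payload = {
    "model": "llava-hf/llava-1.5-7b-hf",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is shown in this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": "http://images.cocodataset.org/val2017/000000039769.jpg"},
                },
            ],
        }
    ],
    "max_tokens": 256,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])
```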
Thanks, I think you are right; we should adopt that API if we support the multimodality feature. However, it is still an experimental feature and would change the current API a lot. I'm not sure if I should do this now; maybe create a separate entrypoint first?
Hi @AzureSilent, which models is this PR compatible with? https://huggingface.co/liuhaotian/llava-v1.5-13b does not seem to work.
It is compatible with the Hugging Face hf-llava model (llava-v1.5-13b); the weight files need to be converted according to Hugging Face's new implementation, which requires transformers >= 4.36, so we switched to the new architecture.
You need to read vLLM's docs...

from vllm import LLM, SamplingParams, LLaVA
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=256)
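A sketch of how that snippet typically continues, based on the `LLaVA` entrypoint usage shown later in this thread (the model, prompt, and image URL are only examples):

```python
# Continuing the snippet above; model, prompt, and image are illustrative.
llm = LLaVA(model="llava-hf/llava-1.5-7b-hf", tensor_parallel_size=1)
prompts = ["USER: <image>\nWhat are these?\nASSISTANT:"]
image = "http://images.cocodataset.org/val2017/000000039769.jpg"

# The PR's generate() takes an images list alongside the prompts (PIL image, URL, or base64).
outputs = llm.generate(prompts, sampling_params, images=[image])
for output in outputs:
    print(output.outputs[0].text)
```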
Hi @AzureSilent, I'm using llava-hf/llava-1.5-7b-hf to run your PR, but I got an error.
Sorry, I forgot about that. Please use the new commit.
Hi @AzureSilent, thanks for putting this up! I've been trying to use your implementation and load the llava-hf/llava-1.5-13b-hf model. However, running the llava_server command

~$ python -m vllm.entrypoints.llava_server --model llava-hf/llava-1.5-13b-hf --tensor-parallel-size 1 --gpu-memory-utilization 0.90

fails with the following output:
WARNING 01-14 16:23:23 config.py:498] The model's config.json does not contain any of the following keys to determine the original maximum length of the model: ['max_position_embeddings', 'n_positions', 'max_seq_len', 'seq_length', 'max_sequence_length', 'max_seq_length', 'seq_len']. Assuming the model's maximum length is 2048.
INFO 01-14 16:23:23 llm_engine.py:73] Initializing an LLM engine with config: model='llava-hf/llava-1.5-13b-hf', tokenizer='llava-hf/llava-1.5-13b-hf', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=1, quantization=None, enforce_eager=False, seed=0)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Traceback (most recent call last):
File "/opt/conda/envs/llava_vllm/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/opt/conda/envs/llava_vllm/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/gcpuser/sky_workdir/vllm/vllm/entrypoints/llava_server.py", line 113, in <module>
engine = AsyncLLaVAEngine.from_engine_args(engine_args)
File "/home/gcpuser/sky_workdir/vllm/vllm/engine/async_llm_engine.py", line 494, in from_engine_args
engine = cls(parallel_config.worker_use_ray,
File "/home/gcpuser/sky_workdir/vllm/vllm/engine/async_llm_engine.py", line 267, in __init__
self.engine = self._init_engine(*args, **kwargs)
File "/home/gcpuser/sky_workdir/vllm/vllm/engine/async_llm_engine.py", line 312, in _init_engine
return engine_class(*args, **kwargs)
File "/home/gcpuser/sky_workdir/vllm/vllm/engine/llava_engine.py", line 18, in __init__
super().__init__(*args, **kwargs)
File "/home/gcpuser/sky_workdir/vllm/vllm/engine/llm_engine.py", line 114, in __init__
self._init_workers(distributed_init_method)
File "/home/gcpuser/sky_workdir/vllm/vllm/engine/llm_engine.py", line 150, in _init_workers
self._run_workers(
File "/home/gcpuser/sky_workdir/vllm/vllm/engine/llm_engine.py", line 755, in _run_workers
self._run_workers_in_batch(workers, method, *args, **kwargs))
File "/home/gcpuser/sky_workdir/vllm/vllm/engine/llm_engine.py", line 729, in _run_workers_in_batch
output = executor(*args, **kwargs)
File "/home/gcpuser/sky_workdir/vllm/vllm/worker/worker.py", line 79, in load_model
self.model_runner.load_model()
File "/home/gcpuser/sky_workdir/vllm/vllm/worker/model_runner.py", line 61, in load_model
self.model = get_model(self.model_config)
File "/home/gcpuser/sky_workdir/vllm/vllm/model_executor/model_loader.py", line 36, in get_model
model_class = _get_model_architecture(model_config.hf_config)
File "/home/gcpuser/sky_workdir/vllm/vllm/model_executor/model_loader.py", line 30, in _get_model_architecture
raise ValueError(
ValueError: Model architectures ['LlavaForCausalLM'] are not supported for now. Supported architectures: ['AquilaModel', 'AquilaForCausalLM', 'BaiChuanForCausalLM', 'BaichuanForCausalLM', 'BloomForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', 'DeciLMForCausalLM', 'FalconForCausalLM', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTJForCausalLM', 'GPTNeoXForCausalLM', 'InternLMForCausalLM', 'LlamaForCausalLM', 'LlavaForConditionalGeneration', 'LLaMAForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MPTForCausalLM', 'OPTForCausalLM', 'PhiForCausalLM', 'QWenLMHeadModel', 'RWForCausalLM', 'YiForCausalLM']

I presume that this happens because the architecture is set to 'LlavaForCausalLM' in the model's config, rather than the 'LlavaForConditionalGeneration' class that the transformers example uses:

import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
model_id = "llava-hf/llava-1.5-13b-hf"
prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
model = LlavaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(0)
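For completeness, a sketch of how that transformers example usually continues (following the standard `AutoProcessor` flow from the hf-llava model card; not part of the original comment):

```python
import requests
from PIL import Image

# Prepare the processor and a PIL image for the <image> placeholder in the prompt.
processor = AutoProcessor.from_pretrained(model_id)
raw_image = Image.open(requests.get(image_file, stream=True).raw)

inputs = processor(text=prompt, images=raw_image, return_tensors="pt").to(0, torch.float16)
output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
print(processor.decode(output[0], skip_special_tokens=True))
```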
~$ python -m vllm.entrypoints.llava_server \
--model llava-hf/llava-1.5-13b-hf \
--tensor-parallel-size 1 \
--gpu-memory-utilization 0.90
config.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.09k/1.09k [00:00<00:00, 281kB/s]
WARNING 01-14 16:45:42 config.py:498] The model's config.json does not contain any of the following keys to determine the original maximum length of the model: ['max_position_embeddings', 'n_positions', 'max_seq_len', 'seq_length', 'max_sequence_length', 'max_seq_length', 'seq_len']. Assuming the model's maximum length is 2048.
INFO 01-14 16:45:42 llm_engine.py:73] Initializing an LLM engine with config: model='llava-hf/llava-1.5-13b-hf', tokenizer='llava-hf/llava-1.5-13b-hf', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=1, quantization=None, enforce_eager=False, seed=0)
tokenizer_config.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.33k/1.33k [00:00<00:00, 363kB/s]
tokenizer.model: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500k/500k [00:00<00:00, 12.4MB/s]
tokenizer.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.84M/1.84M [00:00<00:00, 14.2MB/s]
added_tokens.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 41.0/41.0 [00:00<00:00, 25.3kB/s]
special_tokens_map.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 438/438 [00:00<00:00, 273kB/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
model-00006-of-00006.safetensors: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 2.02G/2.02G [00:22<00:00, 89.2MB/s]
model-00002-of-00006.safetensors: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.97G/4.97G [00:44<00:00, 112MB/s]
model-00001-of-00006.safetensors: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.96G/4.96G [00:44<00:00, 111MB/s]
model-00003-of-00006.safetensors: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.88G/4.88G [00:44<00:00, 109MB/s]
model-00004-of-00006.safetensors: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.93G/4.93G [00:45<00:00, 109MB/s]
model-00005-of-00006.safetensors: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.93G/4.93G [00:46<00:00, 107MB/s]
INFO 01-14 16:46:40 llm_engine.py:227] # GPU blocks: 842, # CPU blocks: 327
INFO 01-14 16:46:43 model_runner.py:456] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 01-14 16:46:43 model_runner.py:460] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider decreasing `gpu_memory_utilization` or enforcing eager mode.
INFO 01-14 16:46:48 model_runner.py:502] Graph capturing finished in 5 secs.
preprocessor_config.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 557/557 [00:00<00:00, 147kB/s]
INFO: Started server process [16998]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
Hi @isaac-vidas, thanks for your reply! I think this is just a minor issue caused by inconsistent naming: in the config.json of llava-hf/llava-1.5-13b-hf the architecture is 'LlavaForCausalLM', but in llava-hf/llava-1.5-7b-hf it is 'LlavaForConditionalGeneration'. I apologize for not noticing this earlier; I will make the code compatible. For now, all you need to do is change the architectures entry in the config.json file to 'LlavaForConditionalGeneration'. This won't impact inference.
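If it helps, a quick sketch of that workaround using a local copy of the model (the local directory is just an example path):

```python
# Download the weights to a local directory, patch the "architectures" field,
# then point vLLM at that directory instead of the hub ID.
import json
import os
from huggingface_hub import snapshot_download

local_dir = snapshot_download("llava-hf/llava-1.5-13b-hf", local_dir="llava-1.5-13b-hf")
cfg_path = os.path.join(local_dir, "config.json")

with open(cfg_path) as f:
    cfg = json.load(f)
cfg["architectures"] = ["LlavaForConditionalGeneration"]
with open(cfg_path, "w") as f:
    json.dump(cfg, f, indent=2)

# e.g. python -m vllm.entrypoints.llava_server --model ./llava-1.5-13b-hf --tensor-parallel-size 1
```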
Great, thanks for the reply and suggestion @AzureSilent! This PR seems to work great and was easy to get running with the HF LLaVA models. Is there any reason why this wasn't merged into the main branch?
Hi, I tried to merge this with the current main branch of vllm on my branch (llava-main):

from vllm import SamplingParams, LLaVA
image = 'URL'
llm = LLaVA(model="llava-hf/llava-1.5-7b-hf", tensor_parallel_size=1,gpu_memory_utilization=0.95, enforce_eager=True)
prompts = [ 'prompt1',
'prompt2 <image> \n say',
'prompt3 <image> say something <image> something',
]
sampling_params = SamplingParams(temperature=0.2, top_p=0.95, max_tokens=256)
outputs = llm.generate(prompts, sampling_params, images=[image]*3) # PIL url or base64

But it has errors:

WARNING 01-30 14:55:51 config.py:584] The model's config.json does not contain any of the following keys to determine the original maximum length of the model: ['max_position_embeddings', 'n_positions', 'max_seq_len', 'seq_length', 'max_sequence_length', 'max_seq_length', 'seq_len']. Assuming the model's maximum length is 2048.
INFO 01-30 14:55:51 llm_engine.py:72] Initializing an LLM engine with config: model='llava-hf/llava-1.5-7b-hf', tokenizer='llava-hf/llava-1.5-7b-hf', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=auto, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto, seed=0)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
WARNING 01-30 14:55:55 custom_all_reduce.py:33] Custom allreduce is disabled due to an unsupported world size: 1. Supported world sizes: [2, 4, 6, 8]. To slience this warning, specifydisable_custom_all_reduce=True explicitly.
INFO 01-30 14:55:56 weight_utils.py:164] Using model weights format ['*.safetensors']
INFO 01-30 14:56:01 llm_engine.py:322] # GPU blocks: 3073, # CPU blocks: 512
<vllm.engine.llava_engine.LLaVAEngine object at 0x7fb8751b10c0>
num_workers: 0
Traceback (most recent call last):
File "../vllm/llava-try.py", line 10, in <module>
outputs = llm.generate(prompts, sampling_params, images=[image]*3) # PIL url or base64。
File "../vllm/vllm/entrypoints/llava_llm.py", line 170, in generate
self._add_request(prompt, sampling_params, token_ids, images)
File "../vllm/vllm/entrypoints/llava_llm.py", line 181, in _add_request
self.llm_engine.add_request(request_id,
File "../vllm/vllm/engine/llava_engine.py", line 69, in add_request
worker = self.workers[np.random.randint(num_workers)]
File "mtrand.pyx", line 781, in numpy.random.mtrand.RandomState.randint
File "_bounded_integers.pyx", line 1334, in numpy.random._bounded_integers._rand_int64
ValueError: high <= 0

I am assuming this is because of the workers not being initialised properly.
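For what it's worth, `np.random.randint(n)` fails exactly this way when `n` is 0, which matches the `num_workers: 0` printout above, so the engine's worker list does appear to be empty after the merge:

```python
import numpy as np

workers = []                      # what the merged engine apparently ends up with
np.random.randint(len(workers))   # raises ValueError, as in the traceback above
```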
Got this working on the llava-main branch.
@Aakash-kaushik Thanks for the great work on your branch. Does it support llava-v1.6 models as well?
Hi @matankley, my branch currently only supports the mentioned model; you might have to do some work to add 1.6.
@matankley While support for LLaVA in vllm is ongoing, you can also look into SGLang. It's built on top of vllm and supports LLaVA 1.5 and 1.6, as well as distributed inference across GPUs and JSON decoding, among other features.
Thank you @AzureSilent for opening this PR! Closing this as we merged #3042. Please feel free to comment on it!
Add support for the LLaVA model:
Note:
Currently, to minimize the impact on existing interfaces, a separate api_server is used.
If the multimodal mode is widely accepted, the two servers can later be merged into one mutually compatible API.
Limit:
Usage:
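A sketch based on the server command and Python entrypoint shown in the conversation above (model name and flags are illustrative):

~$ python -m vllm.entrypoints.llava_server --model llava-hf/llava-1.5-7b-hf --tensor-parallel-size 1 --gpu-memory-utilization 0.90

The offline LLaVA Python class works like vLLM's LLM class, with an additional images argument to generate (PIL image, URL, or base64), as in the snippets earlier in the thread.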
I know there are still many areas that need improvement, but this is a runnable version. If it doesn't meet the release requirements, could this be pulled into a new feature branch so we can continue to improve it?
Thanks for your review!