forked from vllm-project/vllm
[Core][VLM] Test registration for OOT multimodal models (vllm-project#8717)

Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Alvant <alvasian@yandex.ru>
Showing 12 changed files with 227 additions and 49 deletions.
New file (33 additions):

```python
@@ -0,0 +1,33 @@
import importlib
import traceback
from typing import Callable
from unittest.mock import patch


def find_cuda_init(fn: Callable[[], object]) -> None:
    """
    Helper function to debug CUDA re-initialization errors.

    If `fn` initializes CUDA, prints the stack trace of how this happens.
    """
    from torch.cuda import _lazy_init

    stack = None

    def wrapper():
        nonlocal stack
        stack = traceback.extract_stack()
        return _lazy_init()

    with patch("torch.cuda._lazy_init", wrapper):
        fn()

    if stack is not None:
        print("==== CUDA Initialized ====")
        print("".join(traceback.format_list(stack)).strip())
        print("==========================")


if __name__ == "__main__":
    find_cuda_init(
        lambda: importlib.import_module("vllm.model_executor.models.llava"))
```
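This helper is useful beyond the LLaVA check in `__main__`. As a sketch only (the helper's file path is not shown in this capture, so the import below is hypothetical), the same probe can verify that merely importing the dummy plugin package does not initialize CUDA:

```python
import importlib

# Hypothetical import path for the helper above; adjust to wherever
# find_cuda_init actually lives in the test tree.
from find_cuda_init import find_cuda_init

# Importing a plugin package should be CUDA-free: lazy registration
# stores only a "module:class" string until the model is requested.
find_cuda_init(lambda: importlib.import_module("vllm_add_dummy_model"))
```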
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/__init__.py (8 additions & 20 deletions)
```diff
@@ -1,26 +1,14 @@
-from typing import Optional
-
-import torch
-
 from vllm import ModelRegistry
-from vllm.model_executor.models.opt import OPTForCausalLM
-from vllm.model_executor.sampling_metadata import SamplingMetadata
-
-
-class MyOPTForCausalLM(OPTForCausalLM):
-
-    def compute_logits(
-            self, hidden_states: torch.Tensor,
-            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
-        # this dummy model always predicts the first token
-        logits = super().compute_logits(hidden_states, sampling_metadata)
-        if logits is not None:
-            logits.zero_()
-            logits[:, 0] += 1.0
-        return logits
 
 
 def register():
-    # register our dummy model
+    # Test directly passing the model
+    from .my_opt import MyOPTForCausalLM
+
     if "MyOPTForCausalLM" not in ModelRegistry.get_supported_archs():
         ModelRegistry.register_model("MyOPTForCausalLM", MyOPTForCausalLM)
+
+    # Test passing lazy model
+    if "MyLlava" not in ModelRegistry.get_supported_archs():
+        ModelRegistry.register_model("MyLlava",
+                                     "vllm_add_dummy_model.my_llava:MyLlava")
```
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_llava.py (28 additions & 0 deletions)
```python
@@ -0,0 +1,28 @@
from typing import Optional

import torch

from vllm.inputs import INPUT_REGISTRY
from vllm.model_executor.models.llava import (LlavaForConditionalGeneration,
                                              dummy_data_for_llava,
                                              get_max_llava_image_tokens,
                                              input_processor_for_llava)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class MyLlava(LlavaForConditionalGeneration):

    def compute_logits(
            self, hidden_states: torch.Tensor,
            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
        # this dummy model always predicts the first token
        logits = super().compute_logits(hidden_states, sampling_metadata)
        if logits is not None:
            logits.zero_()
            logits[:, 0] += 1.0
        return logits
```
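Note what the two registration styles in `register()` buy you: `MyOPTForCausalLM` is imported eagerly and passed as a class, while `MyLlava` is registered as a `"module:class"` string, so this file (and the CUDA-touching LLaVA machinery it imports) is loaded only when the architecture is actually requested. A rough way to observe that, assuming both vllm and the plugin are installed:

```python
import sys

from vllm import ModelRegistry
import vllm_add_dummy_model

vllm_add_dummy_model.register()

# Both architectures are known to the registry...
assert "MyLlava" in ModelRegistry.get_supported_archs()
# ...but the lazily registered module has not been imported yet.
assert "vllm_add_dummy_model.my_llava" not in sys.modules
```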
tests/plugins/vllm_add_dummy_model/vllm_add_dummy_model/my_opt.py (19 additions & 0 deletions)
```python
@@ -0,0 +1,19 @@
from typing import Optional

import torch

from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata


class MyOPTForCausalLM(OPTForCausalLM):

    def compute_logits(
            self, hidden_states: torch.Tensor,
            sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]:
        # this dummy model always predicts the first token
        logits = super().compute_logits(hidden_states, sampling_metadata)
        if logits is not None:
            logits.zero_()
            logits[:, 0] += 1.0
        return logits
```
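Both dummy models share the same deterministic-output trick: zero every logit, then bump token id 0, so greedy sampling always emits token 0 and tests can assert on the output exactly. A standalone illustration (shapes are arbitrary):

```python
import torch

# Mimic what the overridden compute_logits does to its logits tensor.
logits = torch.randn(2, 32000)  # (num_sequences, vocab_size)
logits.zero_()
logits[:, 0] += 1.0

# Greedy decoding now picks token id 0 for every sequence.
assert bool(torch.argmax(logits, dim=-1).eq(0).all())
```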