
Commit 3683d85

liutongxuan authored and guocuimi committed
[model] support vision language model llava. (#178)
(cherry picked from commit 437be3f)
1 parent 7aeb7fa commit 3683d85

32 files changed: +2,933 −3 lines changed

python/tests/llava_test.py

Lines changed: 28 additions & 0 deletions (new file)

#!/usr/bin/env python3

import torch
from scalellm import VLM


def test_pixel_value_llava_generate():
    vlm = VLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
    )

    prompt = "<image>" * 576 + (
        "\nUSER: What is the content of this image?\nASSISTANT:")

    # This should be provided by another online or offline component.
    image = torch.load("images/stop_sign_pixel_values.pt")

    output = vlm.generate(image, prompt)
    print(output.outputs[0].text)


def main():
    test_pixel_value_llava_generate()


if __name__ == "__main__":
    main()
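
The test loads a precomputed pixel-values tensor from disk. For reference, here is a minimal sketch of how such a tensor could be produced offline with Hugging Face transformers; the AutoImageProcessor call, the input image path, and the output path are illustrative assumptions and are not part of this commit.

# Sketch: precompute LLaVA pixel values for the test above (not part of this commit).
# Assumes the `transformers` and `Pillow` packages; paths are hypothetical.
import torch
from PIL import Image
from transformers import AutoImageProcessor

image_processor = AutoImageProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
image = Image.open("images/stop_sign.jpg")  # any local RGB image
pixel_values = image_processor(images=image, return_tensors="pt")["pixel_values"]
print(pixel_values.shape)  # expected: torch.Size([1, 3, 336, 336]) for llava-1.5
torch.save(pixel_values, "images/stop_sign_pixel_values.pt")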

scalellm/_C/__init__.pyi

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@ from scalellm._C.llm_handler import LLMHandler, Message, Priority
 from scalellm._C.output import (LogProb, LogProbData, RequestOutput,
                                 SequenceOutput, Status, StatusCode, Usage)
 from scalellm._C.sampling_params import SamplingParams
+from scalellm._C.vlm_handler import VLMHandler
 
 # Defined in scalellm/csrc/module.cpp
 def get_metrics() -> str: ...
@@ -18,5 +19,6 @@ __all__ = [
     "StatusCode",
     "Usage",
     "LLMHandler",
+    "VLMHandler",
     "get_metrics",
 ]

scalellm/_C/vlm_handler.pyi

Lines changed: 47 additions & 0 deletions (new file)

from typing import Callable, List, Optional

import torch

from scalellm._C.llm_handler import Future, Priority
from scalellm._C.output import RequestOutput
from scalellm._C.sampling_params import SamplingParams

class VLMHandler:
    class Options:
        def __init__(self) -> None: ...
        def __repr__(self) -> str: ...
        model_path: str
        devices: Optional[str]
        block_size: int
        max_cache_size: int
        max_memory_utilization: float
        enable_prefix_cache: bool
        enable_cuda_graph: bool
        cuda_graph_max_seq_len: int
        cuda_graph_batch_sizes: Optional[List[int]]
        max_tokens_per_batch: int
        max_seqs_per_batch: int
        num_handling_threads: int
        image_input_type: str
        image_token_id: int
        image_input_shape: str
        image_feature_size: int

    def __init__(self, options: Options) -> None: ...
    def __repr__(self) -> str: ...
    def schedule_async(
        self,
        image: torch.Tensor,
        prompt: str,
        sp: SamplingParams,
        priority: Priority,
        stream: bool,
        callback: Callable[[RequestOutput], bool],
    ) -> Future: ...
    def start(self) -> None: ...
    def stop(self) -> None: ...
    def run_until_complete(self) -> None: ...
    def reset(self) -> None: ...
    # helper functions
    def encode(self, text: str) -> List[int]: ...
    def decode(self, tokens: List[int], skip_special_tokens: bool) -> str: ...
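
For orientation, a short sketch of driving this binding directly (the scalellm/vlm.py wrapper added later in this diff follows the same pattern); the model path, tensor file, and option values below are illustrative assumptions.

# Sketch only: exercise VLMHandler through the stub interface above.
# Paths and option values are hypothetical; unspecified options keep their defaults.
import torch
from scalellm._C import Priority, SamplingParams, VLMHandler

options = VLMHandler.Options()
options.model_path = "/models/llava-1.5-7b-hf"  # hypothetical local checkpoint
options.devices = "cuda:0"
options.image_input_type = "pixel_values"
options.image_token_id = 32000
options.image_input_shape = "1,3,336,336"
options.image_feature_size = 576
handler = VLMHandler(options)

outputs = []

def on_output(out) -> bool:
    outputs.append(out)  # collect the final RequestOutput
    return True

image = torch.load("pixel_values.pt")  # precomputed [1, 3, 336, 336] tensor
prompt = "<image>" * 576 + "\nUSER: What is in this image?\nASSISTANT:"
future = handler.schedule_async(
    image, prompt, SamplingParams(), Priority.NORMAL, False, on_output
)
future.wait()                 # block until the request has been scheduled
handler.run_until_complete()  # run all scheduled requests to completion
handler.reset()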

scalellm/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -12,7 +12,7 @@
 
 from scalellm._C import (LLMHandler, LogProb, LogProbData, Message, Priority,
                          RequestOutput, SamplingParams, SequenceOutput, Status,
-                         StatusCode, Usage, get_metrics)
+                         StatusCode, Usage, VLMHandler, get_metrics)
 from scalellm.errors import ValidationError
 from scalellm.llm import LLM
 from scalellm.llm_engine import AsyncLLMEngine, OutputAsyncStream, OutputStream
@@ -34,5 +34,6 @@
     "StatusCode",
     "Usage",
     "LLMHandler",
+    "VLMHandler",
     "get_metrics",
 ]

scalellm/csrc/module.cpp

Lines changed: 3 additions & 1 deletion
@@ -11,6 +11,7 @@ namespace py = pybind11;
 extern void init_sampling_params(py::module_& m);
 extern void init_output(py::module_& m);
 extern void init_llm_handler(py::module_& m);
+extern void init_vlm_handler(py::module_& m);
 
 // NOLINTNEXTLINE
 static std::string get_metrics() { return Metrics::Instance().GetString(); }
@@ -26,6 +27,7 @@ PYBIND11_MODULE(PY_MODULE_NAME, m) {
   init_sampling_params(m);
   init_output(m);
   init_llm_handler(m);
+  init_vlm_handler(m);
 }
 
-}  // namespace llm::csrc
+}  // namespace llm::csrc

scalellm/csrc/vlm_handler.cpp

Lines changed: 115 additions & 0 deletions (new file)

#include "handlers/vlm_handler.h"

#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>

namespace llm::csrc {
namespace py = pybind11;
using namespace pybind11::literals;

void init_vlm_handler(py::module_& m) {
  py::enum_<Priority>(m, "Priority")
      .value("DEFAULT", Priority::NORMAL)
      .value("LOW", Priority::LOW)
      .value("NORMAL", Priority::NORMAL)
      .value("HIGH", Priority::HIGH)
      .export_values();

  py::class_<std::future<bool>>(m, "Future")
      .def("wait",
           &std::future<bool>::wait,
           py::call_guard<py::gil_scoped_release>())
      .def("get",
           &std::future<bool>::get,
           py::call_guard<py::gil_scoped_release>());

  auto vlm_handler =
      py::class_<VLMHandler>(m, "VLMHandler")
          .def(py::init<const VLMHandler::Options&>(), py::arg("options"))
          .def("schedule_async",
               &VLMHandler::schedule_async,
               py::call_guard<py::gil_scoped_release>())
          .def("start",
               &VLMHandler::start,
               py::call_guard<py::gil_scoped_release>())
          .def("stop",
               &VLMHandler::stop,
               py::call_guard<py::gil_scoped_release>())
          .def("run_until_complete",
               &VLMHandler::run_until_complete,
               py::call_guard<py::gil_scoped_release>())
          .def("encode",
               &VLMHandler::encode,
               py::call_guard<py::gil_scoped_release>())
          .def("decode",
               &VLMHandler::decode,
               py::call_guard<py::gil_scoped_release>())
          .def("reset",
               &VLMHandler::reset,
               py::call_guard<py::gil_scoped_release>())
          .def("__repr__", [](const VLMHandler& self) {
            return "VLMHandler({})"_s.format(self.options());
          });

  // VLMHandler::Options
  py::class_<VLMHandler::Options>(vlm_handler, "Options")
      .def(py::init())
      .def_readwrite("model_path", &VLMHandler::Options::model_path_)
      .def_readwrite("devices", &VLMHandler::Options::devices_)
      .def_readwrite("block_size", &VLMHandler::Options::block_size_)
      .def_readwrite("max_cache_size", &VLMHandler::Options::max_cache_size_)
      .def_readwrite("max_memory_utilization",
                     &VLMHandler::Options::max_memory_utilization_)
      .def_readwrite("enable_prefix_cache",
                     &VLMHandler::Options::enable_prefix_cache_)
      .def_readwrite("enable_cuda_graph",
                     &VLMHandler::Options::enable_cuda_graph_)
      .def_readwrite("cuda_graph_max_seq_len",
                     &VLMHandler::Options::cuda_graph_max_seq_len_)
      .def_readwrite("cuda_graph_batch_sizes",
                     &VLMHandler::Options::cuda_graph_batch_sizes_)
      .def_readwrite("max_tokens_per_batch",
                     &VLMHandler::Options::max_tokens_per_batch_)
      .def_readwrite("max_seqs_per_batch",
                     &VLMHandler::Options::max_seqs_per_batch_)
      .def_readwrite("num_handling_threads",
                     &VLMHandler::Options::num_handling_threads_)
      .def_readwrite("image_input_type",
                     &VLMHandler::Options::image_input_type_)
      .def_readwrite("image_token_id", &VLMHandler::Options::image_token_id_)
      .def_readwrite("image_input_shape",
                     &VLMHandler::Options::image_input_shape_)
      .def_readwrite("image_feature_size",
                     &VLMHandler::Options::image_feature_size_)
      .def("__repr__", [](const VLMHandler::Options& self) {
        return "Options(model_path={}, devices={}, "
               "block_size={}, max_cache_size={}, "
               "max_memory_utilization={}, enable_prefix_cache={}, "
               "enable_cuda_graph={}, cuda_graph_max_seq_len={}, "
               "cuda_graph_batch_sizes={}, "
               "max_tokens_per_batch={}, max_seqs_per_batch={}, "
               "num_handling_threads={}, "
               "image_input_type={}, image_token_id={}, "
               "image_input_shape={}, image_feature_size={})"_s.format(
                   self.model_path_,
                   self.devices_,
                   self.block_size_,
                   self.max_cache_size_,
                   self.max_memory_utilization_,
                   self.enable_prefix_cache_,
                   self.enable_cuda_graph_,
                   self.cuda_graph_max_seq_len_,
                   self.cuda_graph_batch_sizes_,
                   self.max_tokens_per_batch_,
                   self.max_seqs_per_batch_,
                   self.num_handling_threads_,
                   self.image_input_type_,
                   self.image_token_id_,
                   self.image_input_shape_,
                   self.image_feature_size_);
      });
}

}  // namespace llm::csrc

scalellm/vlm.py

Lines changed: 127 additions & 0 deletions (new file)

import os
from typing import List, Optional

import torch

from scalellm._C import Priority, RequestOutput, SamplingParams, VLMHandler
from scalellm.downloader import download_hf_model
from scalellm.errors import ValidationError


class VLM:
    def __init__(
        self,
        model: str,
        revision: Optional[str] = None,
        allow_patterns: Optional[str] = None,
        cache_dir: Optional[str] = None,
        convert_to_safetensors: bool = False,
        devices: Optional[str] = None,
        block_size: int = 16,
        max_cache_size: int = 20 * 1024 * 1024 * 1024,
        max_memory_utilization: float = 0.9,
        enable_prefix_cache: bool = True,
        enable_cuda_graph: bool = True,
        cuda_graph_max_seq_len: int = 2048,
        cuda_graph_batch_sizes: Optional[List[int]] = None,
        max_tokens_per_batch: int = 409600,  # a big number to disable chunked prefill
        max_seqs_per_batch: int = 2048,  # a big number for better throughput
        num_handling_threads: int = 4,
        # vision encoder configuration
        image_input_type: Optional[str] = None,
        image_token_id: Optional[int] = None,
        image_input_shape: Optional[str] = None,
        image_feature_size: Optional[int] = None,
    ) -> None:
        # download hf model if it does not exist
        self._model = model
        model_path = model
        if not os.path.exists(model_path):
            model_path = download_hf_model(
                repo_id=model_path,
                revision=revision,
                allow_patterns=allow_patterns,
                cache_dir=cache_dir,
                convert_to_safetensors=convert_to_safetensors,
            )

        options = VLMHandler.Options()
        options.model_path = model_path
        options.devices = devices
        options.block_size = block_size
        options.max_cache_size = max_cache_size
        options.max_memory_utilization = max_memory_utilization
        options.enable_prefix_cache = enable_prefix_cache
        options.enable_cuda_graph = enable_cuda_graph
        options.cuda_graph_max_seq_len = cuda_graph_max_seq_len
        options.cuda_graph_batch_sizes = cuda_graph_batch_sizes
        options.max_tokens_per_batch = max_tokens_per_batch
        options.max_seqs_per_batch = max_seqs_per_batch
        options.num_handling_threads = num_handling_threads
        options.image_input_type = image_input_type
        options.image_token_id = image_token_id
        options.image_input_shape = image_input_shape
        options.image_feature_size = image_feature_size
        # create the VLM handler
        self._handler = VLMHandler(options)

    def generate(
        self,
        image: torch.Tensor = None,
        prompt: str = None,
        sampling_params: Optional[SamplingParams] = None,
        priority: Priority = Priority.NORMAL,
        wait_for_schedule: bool = True,
    ) -> RequestOutput:
        # use default sampling parameters if not provided
        if sampling_params is None:
            sampling_params = SamplingParams()

        output = None

        def callback(async_output: RequestOutput) -> bool:
            nonlocal output
            output = async_output
            return True

        # schedule the request
        future = self._handler.schedule_async(
            image, prompt, sampling_params, priority, False, callback
        )

        # wait for the request to be scheduled
        if wait_for_schedule:
            future.wait()

        # run until all scheduled requests complete
        self._handler.run_until_complete()

        # throw an exception if there is any error
        if output is None:
            raise RuntimeError("Request failed, no output received")
        if output.status is not None and not output.status.ok:
            raise ValidationError(output.status.code, output.status.message)
        # carry over the prompt to the output
        output.prompt = prompt
        return output

    def encode(self, text: str) -> List[int]:
        return self._handler.encode(text)

    def decode(
        self, tokens: List[int], skip_special_tokens: bool = True
    ) -> Optional[str]:
        return self._handler.decode(tokens, skip_special_tokens)

    def __del__(self):
        self._handler.reset()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.__del__()
        return False

    def __repr__(self) -> str:
        return f"VLM(model={self._model})"
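
For completeness, a short usage sketch of the new wrapper as a context manager; the checkpoint name and tensor path mirror the test above, while the SamplingParams fields set here (max_tokens, temperature) are assumptions about that class and are not defined in this diff.

# Sketch: end-to-end generation with the VLM wrapper above (not part of this commit).
import torch
from scalellm import SamplingParams, VLM

with VLM(
    model="llava-hf/llava-1.5-7b-hf",
    image_input_type="pixel_values",
    image_token_id=32000,
    image_input_shape="1,3,336,336",
    image_feature_size=576,
) as vlm:
    prompt = "<image>" * 576 + "\nUSER: Describe this image.\nASSISTANT:"
    image = torch.load("images/stop_sign_pixel_values.pt")
    sp = SamplingParams()
    sp.max_tokens = 64    # assumed SamplingParams field
    sp.temperature = 0.0  # assumed SamplingParams field
    output = vlm.generate(image, prompt, sampling_params=sp)
    print(output.outputs[0].text)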

src/engine/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -10,14 +10,18 @@ cc_library(
     batch.h
     model_runner.h
     worker.h
+    vlm_worker.h
     engine.h
     llm_engine.h
+    vlm_engine.h
   SRCS
     utils.cpp
     batch.cpp
     model_runner.cpp
     worker.cpp
+    vlm_worker.cpp
     llm_engine.cpp
+    vlm_engine.cpp
   DEPS
     torch
     :common
