
Commit 0eb5409

mo374z, finitearth and timo282 authored
v1.2.0 (#31)
* Add vllm as feature and a llm_test_run_script
* small fixes in vllm class
* differentiate between vllm and api inference
* add base llm super class
* add changes from PR review
* change some VLLM params
* add batching to vllm
* Add release notes and increase version number
* change system prompt

---------

Co-authored-by: Tom Zehle <t.zehle@gmail.com>
Co-authored-by: Timo Heiß <ti-heiss@t-online.de>
1 parent b5e8adb commit 0eb5409

8 files changed: +167 −7 lines changed

.flake8

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 [flake8]
 max-line-length = 120
-ignore = F401, W503
+ignore = E731,E231,E203,E501,F401,W503
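
For reference, the newly ignored codes are standard pycodestyle/pyflakes checks: E731 (assignment of a lambda expression), E231 (missing whitespace after a comma), E203 (whitespace before a colon, commonly suppressed for Black compatibility), E501 (line too long; note that ignoring it makes the `max-line-length = 120` setting advisory only), F401 (module imported but unused), and W503 (line break before a binary operator).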

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -6,4 +6,5 @@ rsync_exclude.txt
 __pycache__/
 temp/
 dist/
+outputs/
 poetry.lock

docs/release-notes.md

Lines changed: 14 additions & 1 deletion
@@ -1,10 +1,19 @@
 # Release Notes
 
+## Release v1.2.0
+### What's changed
+#### Added features
+* New LLM wrapper: VLLM for local inference with batches
+
+**Full Changelog**: [here](https://github.com/finitearth/promptolution/compare/v1.1.1...v1.2.0)
+
 ## Release v1.1.1
 ### What's Changed
 #### Further Changes:
 - deleted poetry.lock
-- updated transformers dependency: bumped from 4.46.3 to 4.48.0
+- updated transformers dependency: bumped from 4.46.3 to 4.48.0
+
+**Full Changelog**: [here](https://github.com/finitearth/promptolution/compare/v1.1.0...v1.1.1)
 
 ## Release v1.1.0
 ### What's changed
@@ -16,6 +25,8 @@
 * improved opros meta-prompt
 * added support for python versions from 3.9 onwards (previously 3.11)
 
+**Full Changelog**: [here](https://github.com/finitearth/promptolution/compare/v1.0.1...v1.1.0)
+
 ## Release v1.0.1
 ### What's changed
 #### Added features
@@ -24,6 +35,8 @@
 #### Further Changes:
 * fixed release notes
 
+**Full Changelog**: [here](https://github.com/finitearth/promptolution/compare/v1.0.0...v1.0.1)
+
 ## Release v1.0.0
 ### What's changed
 #### Added Features:

promptolution/llms/__init__.py

Lines changed: 8 additions & 2 deletions
@@ -3,20 +3,23 @@
 from .api_llm import APILLM
 from .base_llm import DummyLLM
 from .local_llm import LocalLLM
+from .vllm import VLLM
 
 
 def get_llm(model_id: str, *args, **kwargs):
     """Factory function to create and return a language model instance based on the provided model_id.
 
     This function supports three types of language models:
     1. DummyLLM: A mock LLM for testing purposes.
-    2. LocalLLM: For running models locally (identified by 'local' in the model_id).
-    3. APILLM: For API-based models (default if not matching other types).
+    2. LocalLLM: For running models locally.
+    3. VLLM: For running models using the vLLM library.
+    4. APILLM: For API-based models (default if not matching other types).
 
     Args:
         model_id (str): Identifier for the model to use. Special cases:
            - "dummy" for DummyLLM
            - "local-{model_name}" for LocalLLM
+           - "vllm-{model_name}" for VLLM
            - Any other string for APILLM
        *args: Variable length argument list passed to the LLM constructor.
        **kwargs: Arbitrary keyword arguments passed to the LLM constructor.
@@ -29,4 +32,7 @@ def get_llm(model_id: str, *args, **kwargs):
     if "local" in model_id:
         model_id = "-".join(model_id.split("-")[1:])
         return LocalLLM(model_id, *args, **kwargs)
+    if "vllm" in model_id:
+        model_id = "-".join(model_id.split("-")[1:])
+        return VLLM(model_id, *args, **kwargs)
     return APILLM(model_id, *args, **kwargs)
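
Taken together, the factory dispatches on a prefix embedded in `model_id` and strips it before constructing the wrapper. A minimal usage sketch (the model name below is an illustrative placeholder, not something this commit ships):

```python
from promptolution.llms import get_llm

# The "vllm-" prefix routes to VLLM; everything after the first "-" is
# rejoined, so model names that themselves contain dashes survive intact:
# "vllm-meta-llama/Llama-3.1-8B-Instruct" -> "meta-llama/Llama-3.1-8B-Instruct"
llm = get_llm("vllm-meta-llama/Llama-3.1-8B-Instruct", batch_size=32)

# All wrappers share the same call surface.
responses = llm.get_response(["What does prompt optimization mean?"])
print(responses[0])
```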

promptolution/llms/api_llm.py

Lines changed: 3 additions & 1 deletion
@@ -13,6 +13,8 @@
 from langchain_core.messages import HumanMessage
 from langchain_openai import ChatOpenAI
 
+from promptolution.llms.base_llm import BaseLLM
+
 logger = Logger(__name__)
 logger.setLevel(INFO)
 
@@ -46,7 +48,7 @@ async def invoke_model(prompt, model, semaphore):
         await asyncio.sleep(delay)
 
 
-class APILLM:
+class APILLM(BaseLLM):
     """A class to interface with various language models through their respective APIs.
 
     This class supports Claude (Anthropic), GPT (OpenAI), and LLaMA (DeepInfra) models.
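
The shared superclass itself is not part of this diff (only the `DummyLLM` import from `base_llm` appears above), so the following is a hypothetical sketch of the minimal interface that `APILLM`, `LocalLLM`, and `VLLM` would need to share — an assumption for illustration, not the committed code:

```python
from abc import ABC, abstractmethod


class BaseLLM(ABC):
    """Hypothetical sketch only: the real base_llm.py is not shown in this commit."""

    @abstractmethod
    def get_response(self, prompts: list[str]) -> list[str]:
        """Return one generated response per input prompt."""
        ...
```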

promptolution/llms/local_llm.py

Lines changed: 3 additions & 1 deletion
@@ -8,8 +8,10 @@
     logger = logging.getLogger(__name__)
     logger.warning(f"Could not import torch or transformers in local_llm.py: {e}")
 
+from promptolution.llms.base_llm import BaseLLM
 
-class LocalLLM:
+
+class LocalLLM(BaseLLM):
     """A class for running language models locally using the Hugging Face Transformers library.
 
     This class sets up a text generation pipeline with specified model parameters

promptolution/llms/vllm.py

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
+"""Module for running language models locally using the vLLM library."""
+
+
+from logging import INFO, Logger
+
+try:
+    import torch
+    from transformers import AutoTokenizer
+    from vllm import LLM, SamplingParams
+except ImportError as e:
+    import logging
+
+    logger = logging.getLogger(__name__)
+    logger.warning(f"Could not import vllm, torch or transformers in vllm.py: {e}")
+
+from promptolution.llms.base_llm import BaseLLM
+
+logger = Logger(__name__)
+logger.setLevel(INFO)
+
+
+class VLLM(BaseLLM):
+    """A class for running language models using the vLLM library.
+
+    This class sets up a vLLM inference engine with specified model parameters
+    and provides a method to generate responses for given prompts.
+
+    Attributes:
+        llm (vllm.LLM): The vLLM inference engine.
+        tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the model.
+        sampling_params (vllm.SamplingParams): Parameters for text generation.
+
+    Methods:
+        get_response: Generate responses for a list of prompts.
+    """
+
+    def __init__(
+        self,
+        model_id: str,
+        batch_size: int = 64,
+        max_generated_tokens: int = 256,
+        temperature: float = 0.1,
+        top_p: float = 0.9,
+        model_storage_path: str = None,
+        token: str = None,
+        dtype: str = "auto",
+        tensor_parallel_size: int = 1,
+        gpu_memory_utilization: float = 0.95,
+        max_model_len: int = 2048,
+        trust_remote_code: bool = False,
+    ):
+        """Initialize the VLLM with a specific model.
+
+        Args:
+            model_id (str): The identifier of the model to use.
+            batch_size (int, optional): The batch size for text generation. Defaults to 64.
+            max_generated_tokens (int, optional): Maximum number of tokens to generate. Defaults to 256.
+            temperature (float, optional): Sampling temperature. Defaults to 0.1.
+            top_p (float, optional): Top-p sampling parameter. Defaults to 0.9.
+            model_storage_path (str, optional): Directory to store the model. Defaults to None.
+            token (str, optional): Token for accessing the model - not used in implementation yet.
+            dtype (str, optional): Data type for model weights. Defaults to "auto".
+            tensor_parallel_size (int, optional): Number of GPUs for tensor parallelism. Defaults to 1.
+            gpu_memory_utilization (float, optional): Fraction of GPU memory to use. Defaults to 0.95.
+            max_model_len (int, optional): Maximum sequence length for the model. Defaults to 2048.
+            trust_remote_code (bool, optional): Whether to trust remote code. Defaults to False.
+
+        Note:
+            This method sets up a vLLM engine with specified parameters for efficient inference.
+        """
+        self.batch_size = batch_size
+        self.dtype = dtype
+        self.tensor_parallel_size = tensor_parallel_size
+        self.gpu_memory_utilization = gpu_memory_utilization
+        self.max_model_len = max_model_len
+        self.trust_remote_code = trust_remote_code
+
+        # Configure sampling parameters
+        self.sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_generated_tokens)
+
+        # Initialize the vLLM engine
+        self.llm = LLM(
+            model=model_id,
+            tokenizer=model_id,
+            dtype=self.dtype,
+            tensor_parallel_size=self.tensor_parallel_size,
+            gpu_memory_utilization=self.gpu_memory_utilization,
+            max_model_len=self.max_model_len,
+            download_dir=model_storage_path,
+            trust_remote_code=self.trust_remote_code,
+        )
+
+        # Initialize tokenizer separately for potential pre-processing
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+    def get_response(self, inputs: list[str]):
+        """Generate responses for a list of prompts using the vLLM engine.
+
+        Args:
+            inputs (list[str]): A list of input prompts.
+
+        Returns:
+            list[str]: A list of generated responses corresponding to the input prompts.
+
+        Note:
+            This method uses vLLM's batched generation capabilities for efficient inference.
+        """
+        prompts = [
+            self.tokenizer.apply_chat_template(
+                [
+                    {
+                        "role": "system",
+                        "content": "You are a helpful assistant.",
+                    },
+                    {"role": "user", "content": input},
+                ],
+                tokenize=False,
+            )
+            for input in inputs
+        ]
+
+        # generate responses for self.batch_size prompts at the same time
+        all_responses = []
+        for i in range(0, len(prompts), self.batch_size):
+            batch = prompts[i : i + self.batch_size]
+            outputs = self.llm.generate(batch, self.sampling_params)
+            responses = [output.outputs[0].text for output in outputs]
+            all_responses.extend(responses)
+
+        return all_responses
+
+    def __del__(self):
+        """Cleanup method to delete the LLM instance and free up GPU memory."""
+        del self.llm
+        torch.cuda.empty_cache()
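
As a direct usage sketch of the new class (the model id is a placeholder; running this needs a CUDA GPU plus the vllm dependency installed):

```python
from promptolution.llms.vllm import VLLM

# Placeholder model id: any Hugging Face model that vLLM supports should work.
llm = VLLM(
    "meta-llama/Llama-3.1-8B-Instruct",
    batch_size=32,             # prompts handed to llm.generate() per iteration
    max_generated_tokens=128,  # cap on tokens per response
    max_model_len=2048,        # prompt + generation must fit in this window
)

# Each input is wrapped in a chat template with the fixed system prompt
# "You are a helpful assistant.", then generated batch by batch.
answers = llm.get_response(["Summarize vLLM in one sentence."])
print(answers[0])
```

One design note visible in the diff: batching here is a plain Python loop over `batch_size`-sized slices, while vLLM also schedules continuous batching internally, so the outer loop mainly bounds how many prompts are enqueued at once.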

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "promptolution"
-version = "1.1.1"
+version = "1.2.0"
 description = ""
 authors = ["Tom Zehle, Moritz Schlager, Timo Heiß"]
 readme = "README.md"
@@ -15,6 +15,7 @@ langchain-community = "^0.2.12"
 pandas = "^2.2.2"
 tqdm = "^4.66.5"
 scikit-learn = "^1.5.2"
+vllm = "^0.7.3"
 
 [tool.poetry.group.dev.dependencies]
 matplotlib = "^3.9.2"
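
Under Poetry's caret semantics, `vllm = "^0.7.3"` resolves to `>=0.7.3,<0.8.0`; as shown, the dependency is added to the main group, so it installs with the package by default.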
