Enhance lora tests with more layer and rank variations #3243
Merged
Commits (18, all by tterrysun):
221aaee  enhance lora tests
d0cc1c9  add test fixture
ccb1872  update requirements
b35b7f1  remove redundant tests
02ec9f9  minor fix
2c820da  verify all tokens and refactor
997f6a5  minorfix
0baf6ba  minor fix
cd2211e  check 32 tokens
fa1f6f0  minor fix
c80ad53  minor refactor
991c577  Merge branch 'main' into lora_test_enhancement
2fda771  minor fix
1e48f26  minor fixes
b430726  ablation study on ci
4543e8f  use smaller model
ce3dd58  add doc string
0033822  remove redundant code
Files changed

The requirements update adds `peft`, which the new test uses to build and merge LoRA adapters:

```diff
@@ -21,6 +21,7 @@ einops # required for MPT
 openai
 requests
 ray
+peft

 # Benchmarking
 aiohttp
```
The new test module (104 added lines) builds a default-initialized LoRA adapter with peft, serves it through vLLM, and compares the output against the same adapter merged into the base weights:

```python
from typing import List, Optional
import peft
import pytest
from random import sample
import tempfile
from transformers import AutoModelForCausalLM

import vllm
from vllm.lora.request import LoRARequest
from .conftest import cleanup

MODEL_PATH = "Felladrin/Llama-68M-Chat-v1"
PROMPTS = [
    "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]",
    "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]",
    "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]",
]


def get_lora_model(model_id: str, target_modules: List[str], rank: int):
    # Build a freshly initialized (untrained) LoRA adapter on top of the
    # base model; it only has to be identical on both sides of the test.
    model = AutoModelForCausalLM.from_pretrained(model_id)
    lora_config = peft.tuners.lora.LoraConfig(target_modules, rank)
    lora_model = peft.PeftModel(model, lora_config)
    return lora_model


def do_sample(llm,
              lora_path: Optional[str] = None,
              lora_id: Optional[int] = None,
              logprobs: int = 0,
              n_tokens: int = 256):
    prompts = PROMPTS
    sampling_params = vllm.SamplingParams(temperature=0,
                                          max_tokens=n_tokens,
                                          logprobs=logprobs,
                                          stop=["[/assistant]"])
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
        if lora_id else None)
    # Collect the generated text and, when requested, the top-k token IDs
    # at every generated position.
    generated_texts = []
    generated_logprobs = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        generated_logprobs.append([
            list(logprob.keys()) for out in output.outputs
            for logprob in out.logprobs
        ])
    return generated_logprobs if logprobs else generated_texts


SUPPORTED_MODULES = [
    "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
    "lm_head"
]
TARGET_MODULES_LIST = []
for length in range(2, 6):
    TARGET_MODULES_LIST.extend(
        [sample(SUPPORTED_MODULES, length) for _ in range(3)])


# Test the correctness when layer and rank are varied
# step 1: init a base model and serve with LoRA to get the reference results
# step 2: merge the same LoRA to the base model, serve the merged model
# step 3: compare the results from step 1 and step 2
@pytest.mark.parametrize("tp_size", [1])
@pytest.mark.parametrize("target_modules", TARGET_MODULES_LIST)
@pytest.mark.parametrize("rank", [8, 16, 32, 64])
def test_layer_variation_correctness(tp_size, target_modules, rank):
    # Step 1: serve the base model with the LoRA adapter applied.
    llm = vllm.LLM(MODEL_PATH,
                   enable_lora=True,
                   max_num_seqs=16,
                   max_loras=4,
                   tensor_parallel_size=tp_size,
                   worker_use_ray=True)
    model = get_lora_model(MODEL_PATH, target_modules, rank)
    with tempfile.TemporaryDirectory() as tmpdir:
        model.save_pretrained(tmpdir)
        merged_probs = do_sample(llm, tmpdir, 1, logprobs=5, n_tokens=32)
    del llm
    cleanup()
    reference_id_sets = [set(prob[0]) for prob in merged_probs]

    # Step 2: merge the same LoRA into the base model and serve the result
    # with LoRA support disabled.
    model = get_lora_model(MODEL_PATH, target_modules, rank)
    with tempfile.TemporaryDirectory() as tmpdir:
        merged_model = model.merge_and_unload()
        merged_model.save_pretrained(tmpdir)
        llm = vllm.LLM(tmpdir,
                       tokenizer=MODEL_PATH,
                       enable_lora=False,
                       max_num_seqs=16,
                       tensor_parallel_size=tp_size,
                       worker_use_ray=True)
        probs = do_sample(llm, logprobs=5, n_tokens=32)
    del llm
    cleanup()
    # Step 3: verify the top-5 tokens are identical for each token
    id_sets = [set(prob[0]) for prob in probs]
    assert id_sets == reference_id_sets
```
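Two properties of the test are worth noting. First, the parameter grid is sizeable: three random module subsets for each subset size from 2 to 5 give 12 target-module combinations, times 4 ranks, for 48 cases per tp_size. Second, the check compares sets of top-5 token IDs rather than ordered lists or exact logprob values, presumably so that tiny numerical differences between the two serving paths, which can reorder near-tied candidates, do not cause spurious failures. A purely illustrative example of what the set comparison tolerates (the token IDs are made up):

```python
# The same five candidates can legitimately come back in a different order
# from the two serving paths; comparing sets accepts that, while a
# positional comparison would flag a spurious failure.
top5_lora = [101, 7, 42, 9, 3]
top5_merged = [101, 7, 9, 42, 3]  # two near-tied tokens swapped
assert set(top5_lora) == set(top5_merged)
assert top5_lora != top5_merged
```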
Review discussion

Reviewer: I currently don't understand this function -- what are the LoRA model weights that are actually applied on top of the meta-llama/Llama-2-7b-hf base model? (This comment predates the later "use smaller model" commit, which switched the test to Felladrin/Llama-68M-Chat-v1.)

Author (tterrysun): It's a default-initialized LoRA; we use the merged one as the golden reference to verify correctness. The LoRA weights won't matter as long as we're using the same one on both sides.
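A minimal sketch of the equivalence the test relies on, under standard LoRA semantics (an editorial illustration, not code from the PR): merging folds the adapter delta into the base weight, so serving base-plus-adapter and serving the merged weights should produce the same logits up to floating-point noise.

```python
# Toy check with assumed shapes: for a single linear layer,
#   y_adapter = W x + (alpha / r) * B (A x)
#   y_merged  = (W + (alpha / r) * B A) x
# are mathematically identical, which is what merge_and_unload exploits.
import torch

d, k, r, alpha = 32, 32, 8, 16
W = torch.randn(d, k)
A = torch.randn(r, k)  # LoRA down-projection
B = torch.randn(d, r)  # LoRA up-projection
x = torch.randn(k)

y_adapter = W @ x + (alpha / r) * (B @ (A @ x))
y_merged = (W + (alpha / r) * (B @ A)) @ x
assert torch.allclose(y_adapter, y_merged, atol=1e-5)
```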
Reviewer: Can you point to where in the docs it says it is a default LoRA, and what that default is? That part was not clear to me (maybe add a comment).
Reply: If I understand correctly, this is the init config: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py#L158-L159
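For context: peft's standard LoRA initialization (the `init_lora_weights=True` default) sets the A matrices with Kaiming-uniform values and the B matrices to zero, so a freshly created adapter's delta B·A starts out exactly zero and the adapted model initially matches the base model. A quick way to confirm this locally (a standalone sketch, not part of the PR; assumes `peft`, `transformers`, and `torch` are installed):

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("Felladrin/Llama-68M-Chat-v1")
config = LoraConfig(r=8, target_modules=["q_proj", "v_proj"])
lora = get_peft_model(base, config)

for name, param in lora.named_parameters():
    if "lora_B" in name:
        # Default init: every lora_B matrix is zero, so the initial
        # adapter delta B @ A contributes nothing to the forward pass.
        assert torch.count_nonzero(param) == 0
```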