Rs/model integration tests logprobs (vllm-project#71)
SUMMARY:
Adds end-to-end model tests.

TEST PLAN:
Compares the logprobs of the HF model's outputs against the vLLM model's outputs at fp16 and bfloat16.
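
For reference, a minimal sketch of the closeness criterion this test plan describes: wherever the two models pick different greedy tokens, each model's choice must still appear in the other model's top-N logprobs at that position. The helper below is illustrative only; the tests themselves use `check_logprobs_close` from `compare_utils`, which is not shown in this diff.

def assert_logprobs_close(outputs_0, outputs_1):
    # Each element is a (token_ids, text, logprobs) tuple, where logprobs holds
    # one {token_id: logprob} dict of the top-N candidates per generated token
    # (the format produced by the runners in tests/conftest.py).
    for (ids_0, _, logprobs_0), (ids_1, _, logprobs_1) in zip(outputs_0, outputs_1):
        for pos, (tok_0, tok_1) in enumerate(zip(ids_0, ids_1)):
            if tok_0 == tok_1:
                continue  # exact greedy match, nothing more to check
            # On a mismatch, each token must be in the other model's top-N set.
            assert tok_0 in logprobs_1[pos], f"token {tok_0} not in top-N at position {pos}"
            assert tok_1 in logprobs_0[pos], f"token {tok_1} not in top-N at position {pos}"
            # After the first divergence the continuations are no longer
            # comparable token by token, so stop checking this prompt.
            break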

---------

Co-authored-by: Michael Goin <michael@neuralmagic.com>
robertgshaw2-neuralmagic and mgoin authored Feb 29, 2024
1 parent c23efd4 commit 2879d9d
Showing 4 changed files with 136 additions and 4 deletions.
2 changes: 2 additions & 0 deletions .github/scripts/run-tests
@@ -64,6 +64,8 @@ do
CUDA_VISIBLE_DEVICES=0,1 pytest --junitxml=${RESULT_XML} ${TEST} || LOCAL_SUCCESS=$?
elif [[ "${TEST}" == *"distributed"* ]]; then
pytest --forked --junitxml=${RESULT_XML} ${TEST} || LOCAL_SUCCESS=$?
elif [[ "${TEST}" == *"models_logprobs"* ]]; then
pytest --forked --junitxml=${RESULT_XML} ${TEST} || LOCAL_SUCCESS=$?
else
pytest --junitxml=${RESULT_XML} ${TEST} || LOCAL_SUCCESS=$?
fi
69 changes: 69 additions & 0 deletions tests/conftest.py
@@ -158,6 +158,75 @@ def hf_runner():
return HfRunner


class HfRunnerNM(HfRunner):

def generate_greedy_logprobs_nm(
self,
prompts: List[str],
max_tokens: int,
num_logprobs: int,
) -> List[Tuple[List[int], str]]:
all_logprobs = []
all_output_ids = []
all_output_strs = []

for prompt in prompts:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
output = self.model.generate(
input_ids.cuda(),
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
output_hidden_states=True,
return_dict_in_generate=True,
)

seq_logprobs = []
for _, hidden_states in enumerate(output.hidden_states):
last_hidden_states = hidden_states[-1][0]
logits = torch.matmul(
last_hidden_states,
self.model.get_output_embeddings().weight.t(),
)
if self.model.get_output_embeddings().bias is not None:
logits += self.model.get_output_embeddings(
).bias.unsqueeze(0)
logprobs = torch.nn.functional.log_softmax(logits,
dim=-1,
dtype=torch.float32)
seq_logprobs.append(logprobs)

# convert to dict
seq_logprobs_lst = []
for tok_idx, tok_logprobs in enumerate(seq_logprobs):
# drop prompt logprobs
if tok_idx == 0:
tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
topk = tok_logprobs.topk(num_logprobs)

tok_logprobs_dct = {}
for token_id, logprob in zip(topk.indices[0], topk.values[0]):
tok_logprobs_dct[token_id.item()] = logprob.item()

seq_logprobs_lst.append(tok_logprobs_dct)

all_logprobs.append(seq_logprobs_lst)
seq_ids = output.sequences[0]
output_len = seq_ids.shape[0] - input_ids.shape[1]
output_ids = seq_ids[-output_len:]
all_output_ids.append(output_ids.tolist())
all_output_strs.append(self.tokenizer.decode(output_ids))

outputs = zip(all_output_ids, all_output_strs, all_logprobs)
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]


@pytest.fixture
def hf_runner_nm():
return HfRunnerNM


class VllmRunner:

def __init__(
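
For context, the value `generate_greedy_logprobs_nm` returns for each prompt is a (token_ids, text, logprobs) tuple, which is exactly what `check_logprobs_close` consumes in the tests below. A hypothetical usage sketch (not part of this commit), relying only on the constructor signature the new test file uses, `hf_runner_nm(model, dtype=...)`; the model choice, prompt, and assertions are placeholders:

def test_logprob_shapes(hf_runner_nm):  # illustrative test, not in the repo
    hf_model = hf_runner_nm("facebook/opt-125m", dtype="half")
    outputs = hf_model.generate_greedy_logprobs_nm(["Hello, my name is"],
                                                   max_tokens=8,
                                                   num_logprobs=3)
    token_ids, text, logprobs = outputs[0]
    # One {token_id: logprob} dict of the top-3 candidates per generated token.
    assert len(logprobs) == len(token_ids)
    assert all(len(top_n) == 3 for top_n in logprobs)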
6 changes: 2 additions & 4 deletions tests/models/test_compressed.py
@@ -2,12 +2,12 @@
Note: sparse kernels do not have bitwise correctness vs the dense models.
As a result, in this test, we just confirm that the top selected tokens of the
sparse models are in the top N selections of same model running dense.
Run `pytest tests/models/test_sparse.py --forked`.
Run `pytest tests/models/test_compressed.py`.
"""

import gc
import pytest
import torch
from compare_utils import check_logprobs_close

MAX_MODEL_LEN = 1024
@@ -44,7 +44,6 @@ def test_models(
# Note: deleting just the model does not always free the GPU memory, not sure why.
del sparse_model.model.llm_engine.driver_worker
del sparse_model
torch.cuda.empty_cache()
gc.collect()

dense_model = vllm_runner_nm(model_name=model_name,
@@ -57,7 +56,6 @@ def test_models(
# Note: deleting just the model does not always free the GPU memory, not sure why.
del dense_model.model.llm_engine.driver_worker
del dense_model
torch.cuda.empty_cache()
gc.collect()

# loop through the prompts
63 changes: 63 additions & 0 deletions tests/models/test_models_logprobs.py
@@ -0,0 +1,63 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.
Run `pytest tests/models/test_models_logprobs.py --forked`.
"""
import pytest
from compare_utils import check_logprobs_close

MODEL_MAX_LEN = 1024

MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
"mistralai/Mistral-7B-v0.1",
"Deci/DeciLM-7b",
"tiiuae/falcon-7b",
"gpt2",
"bigcode/tiny_starcoder_py",
"EleutherAI/gpt-j-6b",
"EleutherAI/pythia-1b", # Switched to 1b model, 70m model logits too unstable. # noqa
"bigscience/bloom-1b1", # Switched to 1b model, 560m model logits too unstable. # noqa
# "mosaicml/mpt-7b", # Failing on the hf_runner, ignore for now. # noqa
"microsoft/phi-2",
# "stabilityai/stablelm-3b-4e1t", # vLLM bug looking up model in ModelRegistry, ignore for now. # noqa
# "allenai/OLMo-1B", # Failing on the hf_runner, ignore for now. (Wait for https://github.com/allenai/OLMo/pull/451 to land in transformers) # noqa
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16", "half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [3])
def test_models(
vllm_runner_nm,
hf_runner_nm,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
hf_model = hf_runner_nm(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy_logprobs_nm(example_prompts,
max_tokens, num_logprobs)

del hf_model

vllm_model = vllm_runner_nm(model,
dtype=dtype,
max_model_len=MODEL_MAX_LEN)
vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts,
max_tokens,
num_logprobs)

del vllm_model.model.llm_engine.driver_worker
del vllm_model

# loop through the prompts
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf_model",
name_1="vllm_model",
)
