This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Rs/model integration tests logprobs #71

Merged
merged 51 commits into from
Feb 29, 2024
Changes from all commits
51 commits
184e7e7
added test_sparse.py
robertgshaw2-neuralmagic Feb 23, 2024
25222f8
added sparsity test in clean pr
robertgshaw2-neuralmagic Feb 23, 2024
f567cd8
Update tests/conftest.py
mgoin Feb 23, 2024
b1d4ca7
added files
robertgshaw2-neuralmagic Feb 26, 2024
1d39dae
added sparse memory tests
robertgshaw2-neuralmagic Feb 26, 2024
366c567
added model test with logprobs
robertgshaw2-neuralmagic Feb 26, 2024
44f7b2a
addressed comments
robertgshaw2-neuralmagic Feb 26, 2024
37eb492
formats
robertgshaw2-neuralmagic Feb 26, 2024
3417974
renamed
robertgshaw2-neuralmagic Feb 26, 2024
14d6235
addressed comments
robertgshaw2-neuralmagic Feb 26, 2024
ede3f23
formats
robertgshaw2-neuralmagic Feb 26, 2024
16b2727
renamed
robertgshaw2-neuralmagic Feb 26, 2024
3c26da0
readded model test
robertgshaw2-neuralmagic Feb 26, 2024
bef2668
added model logprobs test file
robertgshaw2-neuralmagic Feb 26, 2024
28533f5
Update conftest.py
robertgshaw2-neuralmagic Feb 26, 2024
608f8ff
Update compare_utils.py
robertgshaw2-neuralmagic Feb 26, 2024
b30bfc0
Update __init__.py
robertgshaw2-neuralmagic Feb 26, 2024
a449161
added test_sparse.py
robertgshaw2-neuralmagic Feb 23, 2024
60dfe82
added sparsity test in clean pr
robertgshaw2-neuralmagic Feb 23, 2024
a90e1e1
Update tests/conftest.py
mgoin Feb 23, 2024
30eea98
added files
robertgshaw2-neuralmagic Feb 26, 2024
8eae93c
added sparse memory tests
robertgshaw2-neuralmagic Feb 26, 2024
e76d602
addressed comments
robertgshaw2-neuralmagic Feb 26, 2024
ed8ff4e
formats
robertgshaw2-neuralmagic Feb 26, 2024
bacaee5
renamed
robertgshaw2-neuralmagic Feb 26, 2024
3e7245b
Merge branch 'rs/sparse-integration-test-clean' of github.com:neuralm…
robertgshaw2-neuralmagic Feb 26, 2024
a19e694
added model test with logprobs
robertgshaw2-neuralmagic Feb 26, 2024
7e9f21f
readded model test
robertgshaw2-neuralmagic Feb 26, 2024
cab0560
added model logprobs test file
robertgshaw2-neuralmagic Feb 26, 2024
1c1d9c2
Update conftest.py
robertgshaw2-neuralmagic Feb 26, 2024
8ec5bc5
Update compare_utils.py
robertgshaw2-neuralmagic Feb 26, 2024
db24b25
Update __init__.py
robertgshaw2-neuralmagic Feb 26, 2024
5815e32
Merge branch 'rs/model-integration-tests-logprobs' of github.com:neur…
robertgshaw2-neuralmagic Feb 26, 2024
2648e68
./format.sh
robertgshaw2-neuralmagic Feb 26, 2024
9dc3555
Merge branch 'main' into rs/sparse-integration-test-clean
robertgshaw2-neuralmagic Feb 27, 2024
4cdde9b
Merge branch 'main' into rs/sparse-integration-test-clean
robertgshaw2-neuralmagic Feb 28, 2024
48f7ee4
Update test_compressed.py
robertgshaw2-neuralmagic Feb 28, 2024
4b6abca
Update test_cache.py
robertgshaw2-neuralmagic Feb 28, 2024
f4dfa62
Update test_attention.py
robertgshaw2-neuralmagic Feb 28, 2024
cd7c6fa
Update test_attention.py
robertgshaw2-neuralmagic Feb 28, 2024
31ed094
Update test_cache.py
robertgshaw2-neuralmagic Feb 28, 2024
a580bf7
Update test_attention.py
robertgshaw2-neuralmagic Feb 28, 2024
9ed6f71
Update test_cache.py
robertgshaw2-neuralmagic Feb 28, 2024
4cdbc5e
format
robertgshaw2-neuralmagic Feb 28, 2024
8a675e5
Merge branch 'rs/sparse-integration-test-clean' into rs/model-integra…
robertgshaw2-neuralmagic Feb 28, 2024
0d4ad60
Merge branch 'main' into rs/model-integration-tests-logprobs
robertgshaw2-neuralmagic Feb 28, 2024
0670821
Update requirements-dev.txt
robertgshaw2-neuralmagic Feb 28, 2024
26deb20
Update test_attention.py
robertgshaw2-neuralmagic Feb 28, 2024
7325140
Update test_cache.py
robertgshaw2-neuralmagic Feb 28, 2024
f73702f
updated logprobs test to run --forked; not sure why this is needed bu…
robertgshaw2-neuralmagic Feb 29, 2024
1bbaf72
added
robertgshaw2-neuralmagic Feb 29, 2024
2 changes: 2 additions & 0 deletions .github/scripts/run-tests
@@ -64,6 +64,8 @@ do
CUDA_VISIBLE_DEVICES=0,1 pytest --junitxml=${RESULT_XML} ${TEST} || LOCAL_SUCCESS=$?
elif [[ "${TEST}" == *"distributed"* ]]; then
pytest --forked --junitxml=${RESULT_XML} ${TEST} || LOCAL_SUCCESS=$?
elif [[ "${TEST}" == *"models_logprobs"* ]]; then
pytest --forked --junitxml=${RESULT_XML} ${TEST} || LOCAL_SUCCESS=$?
else
pytest --junitxml=${RESULT_XML} ${TEST} || LOCAL_SUCCESS=$?
fi
69 changes: 69 additions & 0 deletions tests/conftest.py
@@ -156,6 +156,75 @@ def hf_runner():
return HfRunner


class HfRunnerNM(HfRunner):

def generate_greedy_logprobs_nm(
self,
prompts: List[str],
max_tokens: int,
num_logprobs: int,
) -> List[Tuple[List[int], str]]:
all_logprobs = []
all_output_ids = []
all_output_strs = []

for prompt in prompts:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
output = self.model.generate(
input_ids.cuda(),
use_cache=True,
do_sample=False,
max_new_tokens=max_tokens,
output_hidden_states=True,
return_dict_in_generate=True,
)

seq_logprobs = []
for _, hidden_states in enumerate(output.hidden_states):
last_hidden_states = hidden_states[-1][0]
logits = torch.matmul(
last_hidden_states,
self.model.get_output_embeddings().weight.t(),
)
if self.model.get_output_embeddings().bias is not None:
logits += self.model.get_output_embeddings(
).bias.unsqueeze(0)
logprobs = torch.nn.functional.log_softmax(logits,
dim=-1,
dtype=torch.float32)
seq_logprobs.append(logprobs)

# convert to dict
seq_logprobs_lst = []
for tok_idx, tok_logprobs in enumerate(seq_logprobs):
# drop prompt logprobs
if tok_idx == 0:
tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
topk = tok_logprobs.topk(num_logprobs)

tok_logprobs_dct = {}
for token_id, logprob in zip(topk.indices[0], topk.values[0]):
tok_logprobs_dct[token_id.item()] = logprob.item()

seq_logprobs_lst.append(tok_logprobs_dct)

all_logprobs.append(seq_logprobs_lst)
seq_ids = output.sequences[0]
output_len = seq_ids.shape[0] - input_ids.shape[1]
output_ids = seq_ids[-output_len:]
all_output_ids.append(output_ids.tolist())
all_output_strs.append(self.tokenizer.decode(output_ids))

outputs = zip(all_output_ids, all_output_strs, all_logprobs)
return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]


@pytest.fixture
def hf_runner_nm():
return HfRunnerNM


class VllmRunner:

def __init__(
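
For orientation when reading the test files below: each per-prompt tuple produced by generate_greedy_logprobs_nm pairs the generated token ids and decoded text with one top-k logprob dict per output position (the vLLM runner's generate_greedy_logprobs presumably returns the same shape, since both feed the same comparison helper). A toy example of that shape, with every id and value invented purely for illustration:

# Hypothetical per-prompt output of HfRunnerNM.generate_greedy_logprobs_nm;
# all token ids and logprob values below are invented for illustration only.
example_output = (
    [464, 2068, 7586],                  # generated token ids
    "The quick brown",                  # decoded output string
    [                                   # one top-k logprob dict per token
        {464: -0.02, 1212: -4.10, 318: -5.30},
        {2068: -0.11, 845: -2.80, 588: -3.60},
        {7586: -0.07, 2330: -3.00, 290: -4.90},
    ],
)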
6 changes: 2 additions & 4 deletions tests/models/test_compressed.py
@@ -2,12 +2,12 @@
Note: sparse kernels do not have bitwise correctness vs the dense models.
As a result, in this test, we just confirm that the top selected tokens of the
sparse models are in the top N selections of same model running dense.
Run `pytest tests/models/test_sparse.py --forked`.

Run `pytest tests/models/test_compressed.py`.
"""

import gc
import pytest
import torch
from compare_utils import check_logprobs_close

MAX_MODEL_LEN = 1024
@@ -44,7 +44,6 @@ def test_models(
# Note: deleting just the model does not always free the GPU memory, not sure why.
del sparse_model.model.llm_engine.driver_worker
del sparse_model
torch.cuda.empty_cache()
gc.collect()

dense_model = vllm_runner_nm(model_name=model_name,
@@ -57,7 +56,6 @@
# Note: deleting just the model does not always free the GPU memory, not sure why.
del dense_model.model.llm_engine.driver_worker
del dense_model
torch.cuda.empty_cache()
gc.collect()

# loop through the prompts
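
The check_logprobs_close helper imported from compare_utils is not included in this PR's diff. Below is a minimal sketch, not the actual implementation, of the check the docstring describes (each model's chosen token must appear in the other model's top-N logprobs), assuming it receives the (token_ids, text, logprobs) tuples illustrated above; the function name and types are assumptions for illustration:

# Illustrative reconstruction only: compare_utils.check_logprobs_close is not
# shown in this PR, so this sketches the described behaviour under assumed types.
from typing import Dict, List, Tuple

def check_logprobs_close_sketch(
    outputs_0_lst: List[Tuple[List[int], str, List[Dict[int, float]]]],
    outputs_1_lst: List[Tuple[List[int], str, List[Dict[int, float]]]],
    name_0: str,
    name_1: str,
) -> None:
    for prompt_idx, (outputs_0, outputs_1) in enumerate(
            zip(outputs_0_lst, outputs_1_lst)):
        output_ids_0, _, logprobs_0 = outputs_0
        output_ids_1, _, logprobs_1 = outputs_1
        for idx, (id_0, id_1) in enumerate(zip(output_ids_0, output_ids_1)):
            if id_0 == id_1:
                # Same greedy pick at this position: nothing to check.
                continue
            # Divergence: each model's pick must be in the other's top-N set.
            assert id_0 in logprobs_1[idx], (
                f"{name_0} token {id_0} not in {name_1} top logprobs "
                f"(prompt {prompt_idx}, position {idx})")
            assert id_1 in logprobs_0[idx], (
                f"{name_1} token {id_1} not in {name_0} top logprobs "
                f"(prompt {prompt_idx}, position {idx})")
            # After the first divergence the continuations differ, so later
            # positions are no longer comparable token-by-token.
            break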
63 changes: 63 additions & 0 deletions tests/models/test_models_logprobs.py
@@ -0,0 +1,63 @@
"""Compare the outputs of HF and vLLM when using greedy sampling.

Run `pytest tests/models/test_models_logprobs.py --forked`.
"""
import pytest
from compare_utils import check_logprobs_close

MODEL_MAX_LEN = 1024

MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
"mistralai/Mistral-7B-v0.1",
"Deci/DeciLM-7b",
"tiiuae/falcon-7b",
"gpt2",
"bigcode/tiny_starcoder_py",
"EleutherAI/gpt-j-6b",
"EleutherAI/pythia-1b", # Switched to 1b model, 70m model logits too unstable. # noqa
"bigscience/bloom-1b1", # Switched to 1b model, 560m model logits too unstable. # noqa
# "mosaicml/mpt-7b", # Failing on the hf_runner, ignore for now. # noqa
"microsoft/phi-2",
# "stabilityai/stablelm-3b-4e1t", # vLLM bug looking up model in ModelRegistry, ignore for now. # noqa
# "allenai/OLMo-1B", # Failing on the hf_runner, ignore for now. (Wait for https://github.com/allenai/OLMo/pull/451 to land in transformers) # noqa
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16", "half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [3])
def test_models(
vllm_runner_nm,
hf_runner_nm,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
hf_model = hf_runner_nm(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy_logprobs_nm(example_prompts,
max_tokens, num_logprobs)

del hf_model

vllm_model = vllm_runner_nm(model,
dtype=dtype,
max_model_len=MODEL_MAX_LEN)
vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts,
max_tokens,
num_logprobs)

del vllm_model.model.llm_engine.driver_worker
del vllm_model

# loop through the prompts
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf_model",
name_1="vllm_model",
)