From 2879d9dd2909757b12c8369d7f5e17c8f79fde95 Mon Sep 17 00:00:00 2001
From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Date: Thu, 29 Feb 2024 17:31:47 -0500
Subject: [PATCH] Rs/model integration tests logprobs (#71)

SUMMARY:
Adds end-to-end model tests.

TEST PLAN:
Compares logprobs of the results from the HF model vs the vLLM model at fp16
and bfloat16.

---------

Co-authored-by: Michael Goin
---
 .github/scripts/run-tests            |  2 +
 tests/conftest.py                    | 69 ++++++++++++++++++++++++++++
 tests/models/test_compressed.py      |  6 +--
 tests/models/test_models_logprobs.py | 63 +++++++++++++++++++++++++
 4 files changed, 136 insertions(+), 4 deletions(-)
 create mode 100644 tests/models/test_models_logprobs.py

diff --git a/.github/scripts/run-tests b/.github/scripts/run-tests
index b75eceb89c92f..059af657290d2 100755
--- a/.github/scripts/run-tests
+++ b/.github/scripts/run-tests
@@ -64,6 +64,8 @@ do
         CUDA_VISIBLE_DEVICES=0,1 pytest --junitxml=${RESULT_XML} ${TEST} || LOCAL_SUCCESS=$?
     elif [[ "${TEST}" == *"distributed"* ]]; then
         pytest --forked --junitxml=${RESULT_XML} ${TEST} || LOCAL_SUCCESS=$?
+    elif [[ "${TEST}" == *"models_logprobs"* ]]; then
+        pytest --forked --junitxml=${RESULT_XML} ${TEST} || LOCAL_SUCCESS=$?
     else
         pytest --junitxml=${RESULT_XML} ${TEST} || LOCAL_SUCCESS=$?
     fi
diff --git a/tests/conftest.py b/tests/conftest.py
index aef9847f5b843..9ffceaf193b1f 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -158,6 +158,75 @@ def hf_runner():
     return HfRunner
 
 
+class HfRunnerNM(HfRunner):
+
+    def generate_greedy_logprobs_nm(
+        self,
+        prompts: List[str],
+        max_tokens: int,
+        num_logprobs: int,
+    ) -> List[Tuple[List[int], str]]:
+        all_logprobs = []
+        all_output_ids = []
+        all_output_strs = []
+
+        for prompt in prompts:
+            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+            output = self.model.generate(
+                input_ids.cuda(),
+                use_cache=True,
+                do_sample=False,
+                max_new_tokens=max_tokens,
+                output_hidden_states=True,
+                return_dict_in_generate=True,
+            )
+
+            seq_logprobs = []
+            for _, hidden_states in enumerate(output.hidden_states):
+                last_hidden_states = hidden_states[-1][0]
+                logits = torch.matmul(
+                    last_hidden_states,
+                    self.model.get_output_embeddings().weight.t(),
+                )
+                if self.model.get_output_embeddings().bias is not None:
+                    logits += self.model.get_output_embeddings(
+                    ).bias.unsqueeze(0)
+                logprobs = torch.nn.functional.log_softmax(logits,
+                                                           dim=-1,
+                                                           dtype=torch.float32)
+                seq_logprobs.append(logprobs)
+
+            # convert to dict
+            seq_logprobs_lst = []
+            for tok_idx, tok_logprobs in enumerate(seq_logprobs):
+                # drop prompt logprobs
+                if tok_idx == 0:
+                    tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
+                topk = tok_logprobs.topk(num_logprobs)
+
+                tok_logprobs_dct = {}
+                for token_id, logprob in zip(topk.indices[0], topk.values[0]):
+                    tok_logprobs_dct[token_id.item()] = logprob.item()
+
+                seq_logprobs_lst.append(tok_logprobs_dct)
+
+            all_logprobs.append(seq_logprobs_lst)
+            seq_ids = output.sequences[0]
+            output_len = seq_ids.shape[0] - input_ids.shape[1]
+            output_ids = seq_ids[-output_len:]
+            all_output_ids.append(output_ids.tolist())
+            all_output_strs.append(self.tokenizer.decode(output_ids))
+
+        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
+        return [(output_ids, output_str, output_logprobs)
+                for output_ids, output_str, output_logprobs in outputs]
+
+
+@pytest.fixture
+def hf_runner_nm():
+    return HfRunnerNM
+
+
 class VllmRunner:
 
     def __init__(
diff --git a/tests/models/test_compressed.py b/tests/models/test_compressed.py
index bd106c85852f2..fed9dfb35e881 100644
--- a/tests/models/test_compressed.py
+++ b/tests/models/test_compressed.py
@@ -2,12 +2,12 @@
 Note: sparse kernels do not have bitwise correctness vs the dense models.
 As a result, in this test, we just confirm that the top selected tokens of the
 sparse models are in the top N selections of same model running dense.
-Run `pytest tests/models/test_sparse.py --forked`.
+
+Run `pytest tests/models/test_compressed.py`.
 """
 
 import gc
 import pytest
-import torch
 from compare_utils import check_logprobs_close
 
 MAX_MODEL_LEN = 1024
@@ -44,7 +44,6 @@ def test_models(
     # Note: deleting just the model does not always free the GPU memory, not sure why.
     del sparse_model.model.llm_engine.driver_worker
     del sparse_model
-    torch.cuda.empty_cache()
     gc.collect()
 
     dense_model = vllm_runner_nm(model_name=model_name,
@@ -57,7 +56,6 @@ def test_models(
     # Note: deleting just the model does not always free the GPU memory, not sure why.
     del dense_model.model.llm_engine.driver_worker
     del dense_model
-    torch.cuda.empty_cache()
    gc.collect()
 
     # loop through the prompts
diff --git a/tests/models/test_models_logprobs.py b/tests/models/test_models_logprobs.py
new file mode 100644
index 0000000000000..80cbf2a48efc4
--- /dev/null
+++ b/tests/models/test_models_logprobs.py
@@ -0,0 +1,63 @@
+"""Compare the outputs of HF and vLLM when using greedy sampling.
+
+Run `pytest tests/models/test_models_logprobs.py --forked`.
+"""
+import pytest
+from compare_utils import check_logprobs_close
+
+MODEL_MAX_LEN = 1024
+
+MODELS = [
+    "facebook/opt-125m",
+    "meta-llama/Llama-2-7b-hf",
+    "mistralai/Mistral-7B-v0.1",
+    "Deci/DeciLM-7b",
+    "tiiuae/falcon-7b",
+    "gpt2",
+    "bigcode/tiny_starcoder_py",
+    "EleutherAI/gpt-j-6b",
+    "EleutherAI/pythia-1b",  # Switched to 1b model, 70m model logits too unstable.  # noqa
+    "bigscience/bloom-1b1",  # Switched to 1b model, 560m model logits too unstable.  # noqa
+    # "mosaicml/mpt-7b",  # Failing on the hf_runner, ignore for now.  # noqa
+    "microsoft/phi-2",
+    # "stabilityai/stablelm-3b-4e1t",  # vLLM bug looking up model in ModelRegistry, ignore for now.  # noqa
+    # "allenai/OLMo-1B",  # Failing on the hf_runner, ignore for now. (Wait for https://github.com/allenai/OLMo/pull/451 to land in transformers)  # noqa
+]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["bfloat16", "half"])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("num_logprobs", [3])
+def test_models(
+    vllm_runner_nm,
+    hf_runner_nm,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+    hf_model = hf_runner_nm(model, dtype=dtype)
+    hf_outputs = hf_model.generate_greedy_logprobs_nm(example_prompts,
+                                                      max_tokens, num_logprobs)
+
+    del hf_model
+
+    vllm_model = vllm_runner_nm(model,
+                                dtype=dtype,
+                                max_model_len=MODEL_MAX_LEN)
+    vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts,
+                                                       max_tokens,
+                                                       num_logprobs)
+
+    del vllm_model.model.llm_engine.driver_worker
+    del vllm_model
+
+    # loop through the prompts
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf_model",
+        name_1="vllm_model",
+    )
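
Note (editorial aside, not part of the patch): `check_logprobs_close` is imported from `compare_utils`, which this commit does not touch. To make the TEST PLAN concrete, the following is a minimal, hypothetical sketch of the kind of check such a helper performs, assuming each output element is the `(token_ids, text, logprob_dicts)` tuple produced by `generate_greedy_logprobs_nm` above: whenever the two models pick different tokens at a position, each model's pick should still appear in the other model's top `num_logprobs`. The real helper's signature and failure messages may differ.

# Hypothetical sketch only; not the actual compare_utils.check_logprobs_close.
from typing import Dict, List, Tuple

# One {token_id: logprob} dict of top-k entries per generated token.
SequenceLogprobs = List[Dict[int, float]]
# (output token ids, decoded text, per-token top-k logprobs)
ModelOutput = Tuple[List[int], str, SequenceLogprobs]


def check_logprobs_close(outputs_0_lst: List[ModelOutput],
                         outputs_1_lst: List[ModelOutput],
                         name_0: str,
                         name_1: str) -> None:
    """Assert that two models' greedy generations are top-k compatible."""
    for prompt_idx, (outputs_0, outputs_1) in enumerate(
            zip(outputs_0_lst, outputs_1_lst)):
        output_ids_0, output_str_0, logprobs_0 = outputs_0
        output_ids_1, output_str_1, logprobs_1 = outputs_1

        for idx, (id_0, id_1) in enumerate(zip(output_ids_0, output_ids_1)):
            # If the generations diverge at this position, the token chosen
            # by one model must at least be in the other model's top-k set.
            if id_0 != id_1:
                fail_msg = (f"Test{prompt_idx}:"
                            f"\n{name_0}:\t{output_str_0!r}"
                            f"\n{name_1}:\t{output_str_1!r}")
                assert id_0 in logprobs_1[idx], fail_msg
                assert id_1 in logprobs_0[idx], fail_msg

In the tests above this would be called with the HF outputs as `outputs_0_lst` and the vLLM outputs as `outputs_1_lst` (`name_0="hf_model"`, `name_1="vllm_model"`), so a mismatch would report the prompt index alongside both decoded generations.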