Commit

[CI] Make mistral tests pass (vllm-project#4596)
rkooo567 authored May 8, 2024
1 parent d7740ea commit f6a5930
Showing 5 changed files with 85 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .buildkite/test-pipeline.yaml
@@ -76,7 +76,7 @@ steps:
   #mirror_hardwares: [amd]
   commands:
   - bash ../.buildkite/download-images.sh
-  - pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py
+  - pytest -v -s models --ignore=models/test_llava.py

 - label: Llava Test
   #mirror_hardwares: [amd]
62 changes: 62 additions & 0 deletions tests/conftest.py
@@ -272,6 +272,68 @@ def generate_greedy_logprobs(
             all_logprobs.append(seq_logprobs)
         return all_logprobs

+    def generate_greedy_logprobs_limit(
+        self,
+        prompts: List[str],
+        max_tokens: int,
+        num_logprobs: int,
+    ) -> List[Tuple[List[int], str]]:
+        all_logprobs = []
+        all_output_ids = []
+        all_output_strs = []
+
+        for prompt in prompts:
+            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+            output = self.model.generate(
+                input_ids.cuda(),
+                use_cache=True,
+                do_sample=False,
+                max_new_tokens=max_tokens,
+                output_hidden_states=True,
+                return_dict_in_generate=True,
+            )
+
+            seq_logprobs = []
+            for _, hidden_states in enumerate(output.hidden_states):
+                last_hidden_states = hidden_states[-1][0]
+                logits = torch.matmul(
+                    last_hidden_states,
+                    self.model.get_output_embeddings().weight.t(),
+                )
+                if getattr(self.model.get_output_embeddings(), "bias",
+                           None) is not None:
+                    logits += self.model.get_output_embeddings(
+                    ).bias.unsqueeze(0)
+                logprobs = torch.nn.functional.log_softmax(logits,
+                                                           dim=-1,
+                                                           dtype=torch.float32)
+                seq_logprobs.append(logprobs)
+
+            # convert to dict
+            seq_logprobs_lst = []
+            for tok_idx, tok_logprobs in enumerate(seq_logprobs):
+                # drop prompt logprobs
+                if tok_idx == 0:
+                    tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
+                topk = tok_logprobs.topk(num_logprobs)
+
+                tok_logprobs_dct = {}
+                for token_id, logprob in zip(topk.indices[0], topk.values[0]):
+                    tok_logprobs_dct[token_id.item()] = logprob.item()
+
+                seq_logprobs_lst.append(tok_logprobs_dct)
+
+            all_logprobs.append(seq_logprobs_lst)
+            seq_ids = output.sequences[0]
+            output_len = seq_ids.shape[0] - input_ids.shape[1]
+            output_ids = seq_ids[-output_len:]
+            all_output_ids.append(output_ids.tolist())
+            all_output_strs.append(self.tokenizer.decode(output_ids))
+
+        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
+        return [(output_ids, output_str, output_logprobs)
+                for output_ids, output_str, output_logprobs in outputs]
+
     def __del__(self):
         del self.model
         cleanup()
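Note on the new helper above: for each prompt it returns the generated token ids, the decoded text, and one top-k logprob dict per generated token. A minimal usage sketch follows; the model name, prompt, and the example_usage wrapper are illustrative, not part of the commit.

from typing import Dict, List, Tuple

def example_usage(hf_runner) -> None:
    # hf_runner is the factory fixture from tests/conftest.py.
    hf_model = hf_runner("mistralai/Mistral-7B-Instruct-v0.1", dtype="bfloat16")
    outputs: List[Tuple[List[int], str, List[Dict[int, float]]]] = \
        hf_model.generate_greedy_logprobs_limit(["Hello, my name is"],
                                                max_tokens=8, num_logprobs=5)
    token_ids, text, per_token_logprobs = outputs[0]
    # per_token_logprobs[i] maps token id -> log-probability for the i-th
    # generated token, restricted to the top `num_logprobs` candidates.
    assert len(per_token_logprobs) == len(token_ids)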
2 changes: 1 addition & 1 deletion tests/models/test_big_models.py
@@ -8,7 +8,7 @@

 MODELS = [
     "meta-llama/Llama-2-7b-hf",
-    # "mistralai/Mistral-7B-v0.1", # Broken
+    # "mistralai/Mistral-7B-v0.1", # Tested by test_mistral.py
     # "Deci/DeciLM-7b", # Broken
     # "tiiuae/falcon-7b", # Broken
     "EleutherAI/gpt-j-6b",
33 changes: 18 additions & 15 deletions tests/models/test_mistral.py
@@ -4,37 +4,40 @@
"""
import pytest

from tests.models.utils import check_logprobs_close

MODELS = [
"mistralai/Mistral-7B-Instruct-v0.1",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.skip(
"Two problems: 1. Failing correctness tests. 2. RuntimeError: expected "
"scalar type BFloat16 but found Half (only in CI).")
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_long_prompts,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
# TODO(sang): Sliding window should be tested separately.
hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_long_prompts, max_tokens)
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
del hf_model

vllm_model = vllm_runner(model, dtype=dtype)
vllm_outputs = vllm_model.generate_greedy(example_long_prompts, max_tokens)
vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts,
max_tokens,
num_logprobs)
del vllm_model

for i in range(len(example_long_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = vllm_outputs[i]
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
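The assertions above delegate to check_logprobs_close from tests/models/utils.py, whose source is not part of this diff. As a rough sketch of the idea behind such a check (an approximation, not the actual helper): two runs are treated as equivalent if, wherever their greedy tokens diverge, each run's chosen token still appears in the other run's top-k logprobs at that position.

from typing import Dict, List, Sequence, Tuple

TokensTextLogprobs = Tuple[List[int], str, List[Dict[int, float]]]

def logprobs_close_sketch(outputs_0_lst: Sequence[TokensTextLogprobs],
                          outputs_1_lst: Sequence[TokensTextLogprobs],
                          name_0: str = "hf",
                          name_1: str = "vllm") -> None:
    # Sketch only: tolerate a divergence when each run's chosen token is
    # within the other run's reported top-k candidates at that position.
    for prompt_idx, (out_0, out_1) in enumerate(zip(outputs_0_lst,
                                                    outputs_1_lst)):
        ids_0, _, logprobs_0 = out_0
        ids_1, _, logprobs_1 = out_1
        for pos, (tok_0, tok_1) in enumerate(zip(ids_0, ids_1)):
            if tok_0 == tok_1:
                continue
            assert tok_0 in logprobs_1[pos] and tok_1 in logprobs_0[pos], (
                f"prompt {prompt_idx}, position {pos}: "
                f"{name_0}={tok_0} vs {name_1}={tok_1}")
            # Once the sequences diverge, later tokens are conditioned on
            # different prefixes, so stop comparing this prompt.
            break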
5 changes: 3 additions & 2 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -109,7 +109,7 @@ def _forward(
         key_pass = key[..., self.rotary_dim:]

         self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
-            positions.device)
+            positions.device, dtype=query.dtype)
         cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                      if offsets is not None else positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
@@ -143,7 +143,8 @@ def forward(
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
+        self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
+                                                   dtype=query.dtype)
         # ops.rotary_embedding()/batched_rotary_embedding()
         # are in-place operations that update the query and key tensors.
         if offsets is not None:
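Both hunks make the cached cos/sin table follow the query's dtype in addition to its device, so the rotary math never mixes, say, a float16 cache with bfloat16 activations — the kind of mismatch behind the "expected scalar type BFloat16 but found Half" RuntimeError cited in the old skip reason. A standalone sketch of the pattern, with made-up names and independent of the vLLM classes:

import torch

class CosSinCache:
    """Toy stand-in for a cached rotary table (illustrative, not vLLM code)."""

    def __init__(self, max_positions: int = 16, rotary_dim: int = 8) -> None:
        # Built once at init time, typically in float32 on the CPU.
        inv_freq = 1.0 / (10000**(torch.arange(0, rotary_dim, 2).float() /
                                  rotary_dim))
        t = torch.arange(max_positions).float()
        freqs = torch.outer(t, inv_freq)
        self.cos_sin = torch.cat([freqs.cos(), freqs.sin()], dim=-1)

    def lookup(self, positions: torch.Tensor,
               query: torch.Tensor) -> torch.Tensor:
        # Mirrors the fix: track the query's device *and* dtype before
        # indexing, so downstream ops see a single floating-point dtype.
        self.cos_sin = self.cos_sin.to(positions.device, dtype=query.dtype)
        return self.cos_sin[positions]

# A bfloat16 query pulls the cached table to bfloat16 as well.
cache = CosSinCache()
query = torch.randn(2, 8, dtype=torch.bfloat16)
positions = torch.tensor([0, 3])
assert cache.lookup(positions, query).dtype == torch.bfloat16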