Skip to content

Commit

Permalink
[Core] Change LoRA embedding sharding to support loading methods (#5038)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yard1 authored Jun 7, 2024
1 parent a31cab7 commit ccdc490
Show file tree
Hide file tree
Showing 11 changed files with 661 additions and 129 deletions.
10 changes: 2 additions & 8 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ steps:
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s spec_decode/e2e/test_integration_dist.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py

- label: Distributed Tests (Multiple Groups)
#mirror_hardwares: [amd]
Expand Down Expand Up @@ -138,14 +139,7 @@ steps:
num_gpus: 4
# This test runs llama 13B, so it is required to run on 4 GPUs.
commands:
# Temporarily run this way because we cannot clean up GPU mem usage
# for multi GPU tests.
# TODO(sang): Fix it.
- pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced
- pytest -v -s lora/test_long_context.py::test_batched_rope_kernel
- pytest -v -s lora/test_long_context.py::test_self_consistency
- pytest -v -s lora/test_long_context.py::test_quality
- pytest -v -s lora/test_long_context.py::test_max_len
- pytest -v -s -x lora/test_long_context.py

- label: Tensorizer Test
#mirror_hardwares: [amd]
Expand Down
21 changes: 21 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import contextlib
import gc
import os
import subprocess
import sys
from typing import Any, Dict, List, Optional, Tuple, TypeVar

import pytest
Expand Down Expand Up @@ -522,3 +524,22 @@ def caplog_vllm(temporary_enable_log_propagate, caplog):
# To capture vllm log, we should enable propagate=True temporarily
# because caplog depends on logs propagated to the root logger.
yield caplog


@pytest.fixture(scope="session")
def num_gpus_available():
"""Get number of GPUs without initializing the CUDA context
in current process."""

try:
out = subprocess.run([
sys.executable, "-c",
"import torch; print(torch.cuda.device_count())"
],
capture_output=True,
check=True,
text=True)
except subprocess.CalledProcessError as e:
logger.warning("Failed to get number of GPUs.", exc_info=e)
return 0
return int(out.stdout.strip())
18 changes: 16 additions & 2 deletions tests/lora/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,24 @@ def cleanup():
ray.shutdown()


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
"""Allow subdirectories to skip global cleanup by overriding this fixture.
This can provide a ~10x speedup for non-GPU unit tests since they don't need
to initialize torch.
"""

if request.node.get_closest_marker("skip_global_cleanup"):
return False

return True


@pytest.fixture(autouse=True)
def cleanup_fixture():
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
yield
cleanup()
if should_do_global_cleanup_after_test:
cleanup()


@pytest.fixture
Expand Down
219 changes: 217 additions & 2 deletions tests/lora/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from unittest.mock import patch

import pytest
import torch
Expand Down Expand Up @@ -32,7 +33,7 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
from vllm.model_executor.utils import set_random_seed

from .utils import DummyLoRAManager
Expand Down Expand Up @@ -427,7 +428,8 @@ def _pretest():
logits_processor = LogitsProcessor(
vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
lora_logits_processor = LogitsProcessorWithLoRA(
logits_processor, 1024, linear.weight.dtype, linear.weight.device)
logits_processor, 1024, linear.weight.dtype, linear.weight.device,
None)
lora_logits_processor.create_lora_weights(max_loras, lora_config)

return linear, logits_processor, lora_logits_processor
Expand Down Expand Up @@ -867,3 +869,216 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,

torch.allclose(ref_q, actual_q)
torch.allclose(ref_k, actual_k)


@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize("seed", list(range(256)))
def test_vocab_parallel_embedding_indices(tp_size, seed):
random.seed(seed)
vocab_size = random.randint(4000, 64000)
added_vocab_size = random.randint(0, 1024)
org_vocab_size = vocab_size - added_vocab_size
last_org_vocab_end_index = 0
last_added_vocab_end_index = org_vocab_size
computed_vocab_size = 0
computed_org_vocab_size = 0
computed_added_vocab_size = 0
vocab_size_padded = -1

all_org_tokens = []
all_added_tokens = []
token_ids = []

for tp_rank in range(tp_size):
with patch(
"vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
return_value=tp_rank
), patch(
"vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
return_value=tp_size):
vocab_embedding = VocabParallelEmbedding(
vocab_size, 1, org_num_embeddings=org_vocab_size)
vocab_size_padded = vocab_embedding.num_embeddings_padded
shard_indices = vocab_embedding.shard_indices
# Assert that the ranges are contiguous
assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
assert (shard_indices.added_vocab_start_index ==
last_added_vocab_end_index)

# Ensure that we are not exceeding the vocab size
computed_vocab_size += shard_indices.num_elements_padded
computed_org_vocab_size += shard_indices.num_org_elements
computed_added_vocab_size += shard_indices.num_added_elements

# Ensure that the ranges are not overlapping
all_org_tokens.extend(
range(shard_indices.org_vocab_start_index,
shard_indices.org_vocab_end_index))
all_added_tokens.extend(
range(shard_indices.added_vocab_start_index,
shard_indices.added_vocab_end_index))

token_ids.extend(
range(shard_indices.org_vocab_start_index,
shard_indices.org_vocab_end_index))
token_ids.extend([-1] * (shard_indices.num_org_elements_padded -
shard_indices.num_org_elements))
token_ids.extend(
range(shard_indices.added_vocab_start_index,
shard_indices.added_vocab_end_index))
token_ids.extend([-1] * (shard_indices.num_added_elements_padded -
shard_indices.num_added_elements))

last_org_vocab_end_index = shard_indices.org_vocab_end_index
last_added_vocab_end_index = shard_indices.added_vocab_end_index

assert computed_vocab_size == vocab_size_padded
assert computed_org_vocab_size == org_vocab_size
assert computed_added_vocab_size == added_vocab_size

# Ensure that the ranges are not overlapping
assert len(all_org_tokens) == len(set(all_org_tokens))
assert len(all_added_tokens) == len(set(all_added_tokens))
assert not set(all_org_tokens).intersection(set(all_added_tokens))

token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
assert reindex_mapping is not None or tp_size == 1
if reindex_mapping is not None:
reindexed_token_ids = token_ids_tensor[reindex_mapping]
expected = torch.tensor(list(range(0, vocab_size)))
assert reindexed_token_ids[:vocab_size].equal(expected)
assert torch.all(reindexed_token_ids[vocab_size:] == -1)


def test_get_masked_input_and_mask():
x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

# base tp 1 case, no padding
modified_x, _ = get_masked_input_and_mask(x,
org_vocab_start_index=0,
org_vocab_end_index=8,
added_vocab_start_index=8,
added_vocab_end_index=12,
num_org_vocab_padding=0)
assert torch.equal(x, modified_x)

# tp 2 case, no padding
modified_x_rank_0, _ = get_masked_input_and_mask(x,
org_vocab_start_index=0,
org_vocab_end_index=4,
added_vocab_start_index=8,
added_vocab_end_index=10,
num_org_vocab_padding=0)
modified_x_rank_1, _ = get_masked_input_and_mask(
x,
org_vocab_start_index=4,
org_vocab_end_index=8,
added_vocab_start_index=10,
added_vocab_end_index=12,
num_org_vocab_padding=0)
assert torch.equal(modified_x_rank_0,
torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]))
assert torch.equal(modified_x_rank_1,
torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]))

# tp 4 case, no padding
modified_x_rank_0, _ = get_masked_input_and_mask(x,
org_vocab_start_index=0,
org_vocab_end_index=2,
added_vocab_start_index=8,
added_vocab_end_index=9,
num_org_vocab_padding=0)
modified_x_rank_1, _ = get_masked_input_and_mask(x,
org_vocab_start_index=2,
org_vocab_end_index=4,
added_vocab_start_index=9,
added_vocab_end_index=10,
num_org_vocab_padding=0)
modified_x_rank_2, _ = get_masked_input_and_mask(
x,
org_vocab_start_index=4,
org_vocab_end_index=6,
added_vocab_start_index=10,
added_vocab_end_index=11,
num_org_vocab_padding=0)
modified_x_rank_3, _ = get_masked_input_and_mask(
x,
org_vocab_start_index=6,
org_vocab_end_index=8,
added_vocab_start_index=11,
added_vocab_end_index=12,
num_org_vocab_padding=0)
assert torch.equal(modified_x_rank_0,
torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]))
assert torch.equal(modified_x_rank_1,
torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]))
assert torch.equal(modified_x_rank_2,
torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]))
assert torch.equal(modified_x_rank_3,
torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]))

# base tp 1 case, with padding
modified_x, _ = get_masked_input_and_mask(x,
org_vocab_start_index=0,
org_vocab_end_index=8,
added_vocab_start_index=8,
added_vocab_end_index=12,
num_org_vocab_padding=2)
assert torch.equal(modified_x,
torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]))

# tp 2 case, with padding
modified_x_rank_0, _ = get_masked_input_and_mask(x,
org_vocab_start_index=0,
org_vocab_end_index=4,
added_vocab_start_index=8,
added_vocab_end_index=10,
num_org_vocab_padding=2)
modified_x_rank_1, _ = get_masked_input_and_mask(
x,
org_vocab_start_index=4,
org_vocab_end_index=8,
added_vocab_start_index=10,
added_vocab_end_index=12,
num_org_vocab_padding=2)
assert torch.equal(modified_x_rank_0,
torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]))
assert torch.equal(modified_x_rank_1,
torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]))

# tp 4 case, with padding
modified_x_rank_0, _ = get_masked_input_and_mask(x,
org_vocab_start_index=0,
org_vocab_end_index=2,
added_vocab_start_index=8,
added_vocab_end_index=9,
num_org_vocab_padding=2)
modified_x_rank_1, _ = get_masked_input_and_mask(x,
org_vocab_start_index=2,
org_vocab_end_index=4,
added_vocab_start_index=9,
added_vocab_end_index=10,
num_org_vocab_padding=2)
modified_x_rank_2, _ = get_masked_input_and_mask(
x,
org_vocab_start_index=4,
org_vocab_end_index=6,
added_vocab_start_index=10,
added_vocab_end_index=11,
num_org_vocab_padding=2)
modified_x_rank_3, _ = get_masked_input_and_mask(
x,
org_vocab_start_index=6,
org_vocab_end_index=8,
added_vocab_start_index=11,
added_vocab_end_index=12,
num_org_vocab_padding=2)
assert torch.equal(modified_x_rank_0,
torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]))
assert torch.equal(modified_x_rank_1,
torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]))
assert torch.equal(modified_x_rank_2,
torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]))
assert torch.equal(modified_x_rank_3,
torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]))
17 changes: 7 additions & 10 deletions tests/lora/test_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,10 @@ def do_sample(llm, lora_path: str, lora_id: int):
return generated_texts


@pytest.mark.parametrize("tp_size", [1])
def test_llama_lora(sql_lora_files, tp_size):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < tp_size:
# pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
@pytest.mark.parametrize("tp_size", [1, 2, 4])
def test_llama_lora(sql_lora_files, tp_size, num_gpus_available):
if num_gpus_available < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
Expand Down Expand Up @@ -80,11 +79,9 @@ def test_llama_lora(sql_lora_files, tp_size):
print("removing lora")


@pytest.mark.skip("Requires multiple GPUs")
def test_llama_tensor_parallel_equality(sql_lora_files):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 4:
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
if num_gpus_available < 4:
pytest.skip("Not enough GPUs for tensor parallelism 4")

llm_tp1 = vllm.LLM(MODEL_PATH,
enable_lora=True,
Expand Down
Loading

0 comments on commit ccdc490

Please sign in to comment.