2 changes: 1 addition & 1 deletion examples/offline_inference.py
@@ -14,7 +14,7 @@ def create_parser():
parser = FlexibleArgumentParser()
# Add engine args
EngineArgs.add_cli_args(parser)
parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
parser.set_defaults(model="Qwen/Qwen2.5-3B-Instruct")
parser.set_defaults(max_model_len=1024)

# Add sampling params
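For reference, the new default is equivalent to constructing the engine directly with the vLLM Python API. A minimal sketch (the prompt and sampling settings here are illustrative, not part of the example script):

import vllm

# Mirrors the defaults now set in examples/offline_inference.py.
llm = vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct", max_model_len=1024)
outputs = llm.generate(["What is 1+1? \n"],
                       vllm.SamplingParams(max_tokens=16, temperature=0))
print(outputs[0].outputs[0].text)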
59 changes: 55 additions & 4 deletions tests/lora/test_lora.py
@@ -1,4 +1,5 @@
# https://github.com/vllm-project/vllm/blob/ed10f3cea199a7a1f3532fbe367f5c5479a6cae9/tests/tpu/lora/test_lora.py

import pytest
import vllm
from vllm.lora.request import LoRARequest
@@ -25,10 +26,11 @@ def use_v1_only(monkeypatch: pytest.MonkeyPatch):
yield


def setup_vllm(num_loras: int) -> vllm.LLM:
def setup_vllm(num_loras: int, num_devices: int = 1) -> vllm.LLM:
return vllm.LLM(model="Qwen/Qwen2.5-3B-Instruct",
max_model_len=256,
max_num_seqs=8,
tensor_parallel_size=num_devices,
enable_lora=True,
max_loras=num_loras,
max_lora_rank=8)
@@ -49,7 +51,56 @@ def test_single_lora():
"lora_adapter_2", 2,
"Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter")
output = llm.generate(prompt,
sampling_params=vllm.SamplingParams(max_tokens=256,
sampling_params=vllm.SamplingParams(max_tokens=16,
temperature=0),
lora_request=lora_request)[0].outputs[0].text

answer = output.strip()[0]

assert answer.isdigit()
assert int(answer) == 2


def test_single_lora_spmd():
"""
This test ensures we can run a single LoRA adapter on the TPU backend.
We run "Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter" which
will force Qwen2.5-3B-Instruct to claim 1+1=2.
"""
# max_loras = 1
# engine_args = EngineArgs(
# model="Qwen/Qwen2.5-3B-Instruct",
# max_model_len=256,
# max_num_seqs=8,
# enable_lora=True,
# max_loras=max_loras,
# max_lora_rank=8,
# )
# vllm_config = engine_args.create_engine_config()
# with set_current_vllm_config(vllm_config):
# temp_file = tempfile.mkstemp()[1]
# init_distributed_environment(
# 1,
# 0,
# local_rank=0,
# distributed_init_method=f"file://{temp_file}",
# backend="gloo")
# ensure_model_parallel_initialized(1, 1)

# num_devices = jax.local_device_count()  # TODO: calling this here causes the test to hang.
# To test the SPMD multi-chip case with Qwen2.5-3B-Instruct, only num_devices=2 works:
# the model has num_kv_heads=2 and the attention kernel shards along the KV heads
# (https://github.com/vllm-project/tpu_commons/blob/a489e59c5b3a4d5c28e93775d5323970eecd66c9/tpu_commons/layers/jax/attention_interface.py#L275),
# so the tensor-parallel size must divide num_kv_heads.
num_devices = 2
print(f"Running test_single_lora_spmd with tensor_parallel_size={num_devices}")
llm = setup_vllm(1, num_devices)

prompt = "What is 1+1? \n"

lora_request = LoRARequest(
"lora_adapter_2", 2,
"Username6568/Qwen2.5-3B-Instruct-1_plus_1_equals_2_adapter")
output = llm.generate(prompt,
sampling_params=vllm.SamplingParams(max_tokens=16,
temperature=0),
lora_request=lora_request)[0].outputs[0].text

@@ -82,7 +133,7 @@ def test_lora_hotswapping():
for i, req in enumerate(lora_requests):
output = llm.generate(prompt,
sampling_params=vllm.SamplingParams(
max_tokens=256, temperature=0),
max_tokens=16, temperature=0),
lora_request=req)[0].outputs[0].text
answer = output.strip()[0]

@@ -112,7 +163,7 @@ def test_multi_lora():
for i, req in enumerate(lora_requests):
output = llm.generate(prompt,
sampling_params=vllm.SamplingParams(
max_tokens=256, temperature=0),
max_tokens=16, temperature=0),
lora_request=req)[0].outputs[0].text

answer = output.strip()[0]
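The comment in test_single_lora_spmd above notes that only tensor_parallel_size=2 works for Qwen2.5-3B-Instruct because the attention kernel shards along KV heads. A minimal sketch of that constraint, using a hypothetical helper that is not part of the test suite:

def valid_tp_sizes(num_kv_heads: int, max_devices: int) -> list[int]:
    # Each shard gets a whole number of KV heads, so the tensor-parallel
    # size must divide num_kv_heads evenly.
    return [tp for tp in range(1, max_devices + 1) if num_kv_heads % tp == 0]

# Qwen2.5-3B-Instruct has num_kv_heads=2, so only TP=1 or TP=2 are usable.
assert valid_tp_sizes(num_kv_heads=2, max_devices=8) == [1, 2]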
2 changes: 2 additions & 0 deletions tpu_commons/models/jax/attention.py
@@ -21,6 +21,7 @@ def sharded_ragged_paged_attention(
v_scale: float | None = None,
):
"""Shards along KV heads."""
# Example shapes in the non-SPMD (tp=1) case: q=(16, 16, 128), k=(16, 2, 128), kv_cache=(40660, 16, 2, 2, 128).
qkv_spec = P(None, "model", None)
kv_cache_spec = P(None, None, "model")
in_specs = (
@@ -86,6 +87,7 @@ def attention(
md = attention_metadata

# (T, N, H)
# Example shapes in the non-SPMD (tp=1) case: q=(16, 16, 128), k=(16, 2, 128), kv_cache=(40660, 16, 2, 2, 128).
output, kv_cache = sharded_ragged_paged_attention(
head_dim_original**-0.5, mesh, attention_chunk_size, q_scale, k_scale,
v_scale)(
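The shape comments added above pair with qkv_spec = P(None, "model", None), i.e. queries are sharded along the head dimension. A standalone JAX sketch of that placement, assuming the (T, N, H) shapes from the comments:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices.reshape(1, -1), ("data", "model"))

# (T, N, H) query from the comment above; the 'model' axis splits the N heads,
# so num_heads must be divisible by the number of devices on that axis.
q = jnp.zeros((16, 16, 128), dtype=jnp.bfloat16)
q_sharded = jax.device_put(q, NamedSharding(mesh, P(None, "model", None)))
print(q_sharded.sharding)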
2 changes: 2 additions & 0 deletions tpu_commons/models/vllm/jax_linear_common.py
@@ -127,6 +127,8 @@ def torch_to_jax_param(
tensor = tensor.astype(jax_dtype)

if fused:
# Example from the non-LoRA fused QKV layer: tensor.shape=[3072, 2048],
# output_sizes=[2048, 512, 512], n_shards=4, dim=0, and
# sharding=NamedSharding(mesh=Mesh('data': 1, 'model': 4, axis_types=(Auto, Auto)),
#                        spec=PartitionSpec('model', None), memory_kind=device).
tensor = reorder_concatenated_tensor_for_sharding(
tensor, output_sizes, n_shards, dim)
tensor = jax.device_put(tensor, sharding)
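The comment above records the fused QKV case: a [3072, 2048] weight with output_sizes=[2048, 512, 512] split across n_shards=4 along dim 0. The reordering exists so that an even split of dim 0 hands each shard its own Q, K and V slices. A hypothetical NumPy re-implementation of that idea (reorder_concatenated_tensor_for_sharding itself lives in tpu_commons; this sketch only illustrates the layout):

import numpy as np

def reorder_concatenated(tensor: np.ndarray, output_sizes: list[int],
                         n_shards: int, dim: int = 0) -> np.ndarray:
    # Split the fused tensor into its Q/K/V blocks, shard each block, then
    # lay the pieces out shard-by-shard: [Q0, K0, V0, Q1, K1, V1, ...].
    blocks = np.split(tensor, np.cumsum(output_sizes)[:-1], axis=dim)
    pieces = [np.array_split(block, n_shards, axis=dim) for block in blocks]
    return np.concatenate(
        [p[i] for i in range(n_shards) for p in pieces], axis=dim)

w = np.arange(3072 * 8, dtype=np.float32).reshape(3072, 8)  # stand-in for the [3072, 2048] weight
w_reordered = reorder_concatenated(w, [2048, 512, 512], n_shards=4)
assert w_reordered.shape == w.shape  # same data, shard-friendly layout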
1 change: 1 addition & 0 deletions tpu_commons/models/vllm/quantization/common.py
@@ -60,6 +60,7 @@ def __init__(self, vllm_config: VllmConfig, mesh: Mesh, layer: LinearBase):
"Unsupported linear layer type of %s. Can potentially yield "
" bad performance.", type(layer))

# Non-LoRA example: for qkv_parallel_linear, weight_sharding is PartitionSpec('model', None).
self.bias_sharding = P(self.weight_sharding[0])
self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)

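The comment added above notes that for qkv_parallel_linear the weight_sharding is PartitionSpec('model', None); the bias then follows the weight's sharded output dimension. A small sketch with assumed mesh axis sizes:

from jax.sharding import PartitionSpec as P

weight_sharding = P("model", None)        # output dim sharded on the 'model' axis
bias_sharding = P(weight_sharding[0])     # bias follows the output dim -> P('model')

mesh_shape = {"data": 1, "model": 4}      # assumed axis sizes, mirroring the comment
n_shards = mesh_shape.get(weight_sharding[0], 1)
assert n_shards == 4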
44 changes: 42 additions & 2 deletions tpu_commons/models/vllm/sharding.py
@@ -128,7 +128,47 @@ def _shard_column_parallel_linear_lora(

def _shard_qkv_parallel_linear_lora(layer: MergedQKVParallelLinearWithLoRA,
mesh: Mesh) -> None:
_shard_base_linear_lora(layer, mesh)
# Example mesh: Mesh(axis_sizes=(1, 2), axis_names=('data', 'model'), axis_types=(Auto, Auto)).
# NOTE: lora_a_stacked[i] has shape [max_loras, 1, num_out, num_in]
sharded_lora_a_tpu = torch.nn.ParameterList()
sharded_lora_b_tpu = torch.nn.ParameterList()
sharded_lora_bias_tpu = torch.nn.ParameterList()

assert layer.n_slices > 0, "layer.n_slices should be greater than 0"
mesh_lora_b_shape = (1, 1) + (mesh.shape['data'], mesh.shape['model'])
mesh_lora_b_axis = ('replica_num_lora', 'replica', 'data', 'model')
lora_b_mesh = jax.make_mesh(
mesh_lora_b_shape, mesh_lora_b_axis,
devices=mesh.devices[0])  # mesh.devices has shape (1, n): [[device_0, ..., device_n]]
lora_b_partition_spec = P(None, None, 'model', None)
lora_b_sharding = NamedSharding(lora_b_mesh, lora_b_partition_spec)

mesh_lora_bias_shape = (1, 1) + (mesh.shape['model'], )
mesh_lora_bias_axis = ('replica_num_lora', 'replica', 'model')
lora_bias_mesh = jax.make_mesh(
mesh_lora_bias_shape, mesh_lora_bias_axis,
devices=mesh.devices[0])  # mesh.devices has shape (1, n): [[device_0, ..., device_n]]
lora_bias_partition_spec = P(None, None, 'model')
lora_bias_sharding = NamedSharding(lora_bias_mesh,
lora_bias_partition_spec)

for i in range(layer.n_slices):
sharded_lora_a_tpu.append(
_shard_tensor_to_tpu_replicated(layer.lora_a_stacked[i], mesh))

sharded_lora_b_tpu.append(
_convert_to_torchax_and_shard(layer.lora_b_stacked[i],
lora_b_sharding))

if layer.lora_bias_stacked is not None:
sharded_lora_bias_tpu.append(
_convert_to_torchax_and_shard(layer.lora_bias_stacked[i],
lora_bias_sharding))

layer.lora_a_stacked = sharded_lora_a_tpu
layer.lora_b_stacked = sharded_lora_b_tpu
if layer.lora_bias_stacked is not None:
layer.lora_bias_stacked = sharded_lora_bias_tpu


def _shard_row_parallel_linear_lora(layer: RowParallelLinearWithLoRA,
@@ -152,7 +192,7 @@ def _shard_row_parallel_linear_lora(layer: RowParallelLinearWithLoRA,
def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
for path, module in model.named_modules():
for module_type, sharding_func in MODULE_TYPE_TO_SHARDING_FUNC:
if isinstance(module, module_type):
if type(module) is module_type:
logger.debug("shard %s with %s", path, sharding_func)
sharding_func(module, mesh)
break
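
For the merged QKV LoRA path above, lora_a is replicated while lora_b (and the optional bias) are sharded on the 'model' axis of a 4-D mesh built from the base 2-D mesh. A standalone sketch of that sharding, with assumed shapes ([max_loras, 1, out_features, rank]) and plain JAX arrays standing in for the layer's torchax tensors:

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
base_mesh = Mesh(devices.reshape(1, -1), ("data", "model"))

# 4-D mesh matching the stacked LoRA-B layout: (1, 1, data, model).
lora_b_mesh = jax.make_mesh(
    (1, 1, base_mesh.shape["data"], base_mesh.shape["model"]),
    ("replica_num_lora", "replica", "data", "model"),
    devices=base_mesh.devices[0])
lora_b_sharding = NamedSharding(lora_b_mesh, P(None, None, "model", None))

# Assumed shape: [max_loras, 1, out_features, rank]; out_features is split across 'model'.
lora_b = jnp.zeros((1, 1, 2048, 8), dtype=jnp.bfloat16)
lora_b = jax.device_put(lora_b, lora_b_sharding)
print(lora_b.sharding)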