142 changes: 142 additions & 0 deletions tests/lora/test_layers.py
@@ -1223,3 +1223,145 @@ def test_get_masked_input_and_mask():
                       torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]))
    assert torch.equal(modified_x_rank_3,
                       torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]))


@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4])
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("stage", STAGES)
def test_lm_head_logits_processor_zero_vocab_padding(dist_init, num_loras,
                                                     device, stage) -> None:
    """
    Test LogitsProcessorWithLoRA with lora_extra_vocab_size = 0.

    This is a regression test to ensure that unembed LoRA works correctly
    when no vocab padding is used (lora_extra_vocab_size = 0).
    """
    if current_platform.is_cuda_alike():
        torch.cuda.set_device(device)

    torch.set_default_device(device)
    max_loras = 8
    vocab_size = 32000
    hidden_size = 1024

    # Create LoRAConfig with lora_extra_vocab_size = 0
    lora_config = LoRAConfig(
        max_loras=max_loras,
        max_lora_rank=8,
        lora_extra_vocab_size=0,  # No vocab padding
        lora_dtype=torch.float16)

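    # Shared Punica wrapper: it holds the batched LoRA slot/index metadata
    # that LogitsProcessorWithLoRA._get_logits consumes when applying LoRA.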
    punica_wrapper = get_punica_wrapper(8192, 256, device, max_loras=max_loras)
    assert check_punica_wrapper(punica_wrapper)

    def _pretest():
        # Note: vocab_size + 0 (no extra vocab)
        linear = ParallelLMHead(vocab_size,
                                hidden_size,
                                vocab_size,
                                params_dtype=torch.float16)
        linear.weight.data = torch.rand_like(linear.weight.data)
        logits_processor = LogitsProcessor(vocab_size, vocab_size)
        lora_logits_processor = LogitsProcessorWithLoRA(
            logits_processor, hidden_size, linear.weight.dtype,
            linear.weight.device, None)
        lora_logits_processor.create_lora_weights(max_loras, lora_config)

        return linear, logits_processor, lora_logits_processor

    for i in range(NUM_RANDOM_SEEDS):
        set_random_seed(i)

        id_to_index = get_random_id_to_index(num_loras, max_loras)
        linear, logits_processor, lora_logits_processor = _pretest()
        lora_logits_processor.set_mapping(punica_wrapper)

        # Populate LoRAs without embeddings tensor (since extra_vocab_size = 0)
        lora_dict, _ = populate_loras(
            id_to_index,
            layer=lora_logits_processor,
            layer_weights=linear.weight,
            generate_embeddings_tensor=0,  # No embeddings tensor
        )

        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=list(lora_dict.keys()),
            num_inputs=8 * num_loras,
            input_size=(1, hidden_size),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            0,  # lora_extra_vocab_size = 0
        )

        # Test with LoRA
        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=linear,
            embedding_bias=None)

        # Compute expected results
        expected_results: list[torch.Tensor] = []
        for input_, lora_id in zip(inputs, prompt_mapping):
            lora = lora_dict[lora_id]
            result = logits_processor._get_logits(hidden_states=input_,
                                                  lm_head=linear,
                                                  embedding_bias=None)
            # Apply LoRA transformation
            result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
            expected_results.append(result)
        expected_result = torch.cat(expected_results)

        # Verify results match
        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)

        # Test resetting LoRA weights
        for slot_idx in range(max_loras):
            lora_logits_processor.reset_lora(slot_idx)

        # Test without any active LoRAs
        inputs, index_mapping, prompt_mapping = create_random_inputs(
            active_lora_ids=[0],
            num_inputs=8 * num_loras,
            input_size=(1, hidden_size),
            input_range=(0, 1),
            input_type=torch.float16,
            device=device)
        lora_mapping = LoRAMapping(index_mapping,
                                   prompt_mapping,
                                   is_prefill=stage)
        punica_wrapper.update_metadata(
            lora_mapping,
            id_to_index,
            max_loras,
            vocab_size,
            0,  # lora_extra_vocab_size = 0
        )

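        # With every LoRA slot reset, the LoRA-wrapped processor should match
        # the base logits processor exactly.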
        lora_result = lora_logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=linear,
            embedding_bias=None)
        expected_result = logits_processor._get_logits(
            hidden_states=torch.cat(inputs),
            lm_head=linear,
            embedding_bias=None)

        rtol, atol = TOLERANCES[lora_result.dtype]
        torch.testing.assert_close(lora_result,
                                   expected_result,
                                   rtol=rtol,
                                   atol=atol)