[Bugfix] Fix LoRA test #18518

Merged 1 commit on May 22, 2025
2 changes: 1 addition & 1 deletion tests/lora/test_lora_functions.py
@@ -69,7 +69,7 @@ def run_check(fn, args, expected: list):
     run_check(llm.add_lora, make_lora_request(12), [12, 9, 10, 11])
     run_check(llm.add_lora, make_lora_request(13), [12, 13, 10, 11])
 
-    # Remove all LoRAs
+    # Remove all LoRAs.
     run_check(llm.remove_lora, 13, [12, 10, 11])
     run_check(llm.remove_lora, 12, [10, 11])
     run_check(llm.remove_lora, 11, [10])
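
Each run_check call above applies one LoRA operation and then asserts on the exact set of adapter IDs the engine reports as loaded; the hunk header shows the helper's signature, def run_check(fn, args, expected: list), but not its body. Below is a minimal sketch of what such a helper might look like. The list_loras accessor and the LoRARequest construction are assumptions for illustration, not code from this PR.

```python
from vllm import LLM
from vllm.lora.request import LoRARequest


def make_lora_request(lora_id: int, lora_path: str) -> LoRARequest:
    # Hypothetical helper: build a request whose integer ID doubles as the
    # value checked in the expected lists above.
    return LoRARequest(lora_name=f"lora_{lora_id}",
                       lora_int_id=lora_id,
                       lora_path=lora_path)


def run_check(llm: LLM, fn, args, expected: list):
    # Apply the operation (add_lora / remove_lora), then compare the set of
    # currently loaded adapter IDs against the expected set.
    fn(args)
    # Assumed accessor; the real test may reach the loaded-adapter set
    # through a different attribute.
    assert set(llm.llm_engine.list_loras()) == set(expected)
```
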
136 changes: 72 additions & 64 deletions tests/v1/sample/test_topk_topp_sampler.py
@@ -16,31 +16,40 @@
 FLASHINFER_ENABLED = current_platform.is_cuda() and is_flashinfer_available
 
 
+@pytest.fixture(autouse=True)
+def reset_default_device():
+    """
+    Explicitly set the default device, which can affect subsequent tests.
+    Adding this fixture helps avoid this problem.
+    """
+    original_device = torch.get_default_device()
+    yield
+    torch.set_default_device(original_device)
+
+
 def test_topk_impl_equivalance():
 
-    with torch.device(DEVICE):
-        generator = Generator(device=DEVICE).manual_seed(33)
+    torch.set_default_device(DEVICE)
+    generator = Generator(device=DEVICE).manual_seed(33)
 
-        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
+    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
 
-        # Random top-k values between 1 and 9.
-        k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
+    # Random top-k values between 1 and 9.
+    k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
 
-        # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
-        k.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=bool), VOCAB_SIZE)
+    # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
+    k.masked_fill_(
+        torch.randint(0, 2, (BATCH_SIZE, ), generator=generator, dtype=bool),
+        VOCAB_SIZE)
 
-        # Top-k only implementation
-        result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
+    # Top-k only implementation
+    result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
 
-        # Top-p + top-k
-        no_op_top_p = torch.tensor([1.0])
-        result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
+    # Top-p + top-k
+    no_op_top_p = torch.tensor([1.0])
+    result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
 
-        assert torch.allclose(result1, result2)
+    assert torch.allclose(result1, result2)
 
 
 def test_flashinfer_sampler():
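
The hunk above replaces a scoped with torch.device(DEVICE): block with a bare torch.set_default_device(DEVICE) call, and that is what makes the new autouse reset_default_device fixture necessary: the context manager restores the previous default device when the block exits, whereas set_default_device persists for the rest of the process and would leak into later tests. A small standalone illustration of the difference (assuming a CUDA machine whose default device starts out as CPU; not part of this PR):

```python
import torch

# Scoped: factory functions inside the block allocate on CUDA, and the
# previous default device is restored automatically when the block exits.
with torch.device("cuda"):
    assert torch.empty(1).device.type == "cuda"
assert torch.empty(1).device.type == "cpu"

# Global: set_default_device persists until it is changed again, so later
# tests would also allocate on CUDA unless something resets it. That reset
# is exactly what the autouse fixture in this PR does.
torch.set_default_device("cuda")
assert torch.empty(1).device.type == "cuda"
torch.set_default_device("cpu")
```
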
@@ -58,50 +58,49 @@ def test_flashinfer_sampler():
         pytest.skip(
             "FlashInfer not installed or not available on this platform.")
 
-    with torch.device(DEVICE):
-        generator = Generator(device=DEVICE).manual_seed(42)
-
-        # Generate random logits
-        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
-
-        # Generate various top-k and top-p values
-        k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
-        p_values = torch.rand(
-            (BATCH_SIZE, ),
-            generator=generator) * 0.5 + 0.5  # range in [0.5, 1.0]
-
-        # Sometimes disable top-k (k=vocab_size)
-        k_values.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=torch.bool), VOCAB_SIZE)
-
-        # Sometimes disable top-p (p=1.0)
-        p_values.masked_fill_(
-            torch.randint(0,
-                          2, (BATCH_SIZE, ),
-                          generator=generator,
-                          dtype=torch.bool), 1.0)
-
-        python_logits = apply_top_k_top_p(
-            logits=logits.clone(),
-            k=k_values,
-            p=p_values,
-        )
-        python_probs = torch.softmax(python_logits, dim=-1)
-
-        # FlashInfer only exposed renorm interfaces for probs so convert first
-        flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
-        flashinfer_probs = top_k_renorm_probs(
-            probs=flashinfer_probs,
-            top_k=k_values,
-        )
-        flashinfer_probs = top_p_renorm_probs(
-            probs=flashinfer_probs,
-            top_p=p_values,
-        )
-
-        # Compare the results
-        assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
-            "FlashInfer and Python sampling implementations do not match!"
+    torch.set_default_device(DEVICE)
+    generator = Generator(device=DEVICE).manual_seed(42)
+
+    # Generate random logits
+    logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
+
+    # Generate various top-k and top-p values
+    k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
+    p_values = torch.rand(
+        (BATCH_SIZE, ), generator=generator) * 0.5 + 0.5  # range in [0.5, 1.0]
+
+    # Sometimes disable top-k (k=vocab_size)
+    k_values.masked_fill_(
+        torch.randint(0,
+                      2, (BATCH_SIZE, ),
+                      generator=generator,
+                      dtype=torch.bool), VOCAB_SIZE)
+
+    # Sometimes disable top-p (p=1.0)
+    p_values.masked_fill_(
+        torch.randint(0,
+                      2, (BATCH_SIZE, ),
+                      generator=generator,
+                      dtype=torch.bool), 1.0)
+
+    python_logits = apply_top_k_top_p(
+        logits=logits.clone(),
+        k=k_values,
+        p=p_values,
+    )
+    python_probs = torch.softmax(python_logits, dim=-1)
+
+    # FlashInfer only exposed renorm interfaces for probs so convert first
+    flashinfer_probs = torch.softmax(logits.clone(), dim=-1)
+    flashinfer_probs = top_k_renorm_probs(
+        probs=flashinfer_probs,
+        top_k=k_values,
+    )
+    flashinfer_probs = top_p_renorm_probs(
+        probs=flashinfer_probs,
+        top_p=p_values,
+    )
+
+    # Compare the results
+    assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
+        "FlashInfer and Python sampling implementations do not match!"