Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
LiuXiaoxuanPKU committed Aug 30, 2024
1 parent 4e16ce3 commit 5e57252
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions tests/samplers/test_rejection_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,6 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
@torch.inference_mode()
def test_compare_nonflashinfer_backend(k: int, vocab_size: int,
batch_size: int, device: str):

def get_seeded_seqs():
    # Build one seeded torch.Generator per batch index, keyed by index,
    # so sampling is deterministic per sequence.
    # NOTE(review): `torch.rand(batch_size) <= 1.0` is always true (rand is
    # in [0, 1)), so the mask admits every index and the filter below is a
    # no-op — presumably a leftover from a partial-seeding variant.
    seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= 1.0
    return {
        i: torch.Generator(device=device).manual_seed(i)
        for i in range(batch_size) if seeded_mask[i]
    }

"""
Test the flashinfer and nonflashinfer backend generate
the same output metrics.
Expand All @@ -251,6 +243,13 @@ def get_seeded_seqs():
num_accepted_tokens = []
num_emitted_tokens = []
num_draft_tokens = []

def get_seeded_seqs():
    """Return a fresh, deterministically seeded generator for each sequence.

    Maps every batch index ``i`` to a ``torch.Generator`` on ``device``
    seeded with ``i``, so repeated calls yield identical sampling streams.
    """
    generators = {}
    for seq_idx in range(batch_size):
        generators[seq_idx] = torch.Generator(device=device).manual_seed(seq_idx)
    return generators

for use_flashinfer in [True, False]:
rejection_sampler = RejectionSampler(disable_bonus_tokens=False,
use_flashinfer=use_flashinfer)
Expand Down

0 comments on commit 5e57252

Please sign in to comment.