Skip to content

Commit 6ef2a4f

Browse files
compile random sampling
1 parent df99418 commit 6ef2a4f

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

nanovllm/layers/sampler.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@ class Sampler(nn.Module):
77
def __init__(self):
88
super().__init__()
99

10+
@torch.compile
1011
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
11-
logits = logits.float()
12-
greedy_tokens = logits.argmax(dim=-1)
13-
logits.div_(temperatures.unsqueeze(dim=1))
14-
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
15-
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1) + 1e-10).argmax(dim=-1)
16-
return torch.where(temperatures == 0, greedy_tokens, sample_tokens)
12+
logits = logits.float().div_(temperatures.unsqueeze(dim=1))
13+
probs = torch.softmax(logits, dim=-1)
14+
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1).clamp_min_(1e-10)).argmax(dim=-1)
15+
return sample_tokens

nanovllm/sampling_params.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,6 @@ class SamplingParams:
66
temperature: float = 1.0
77
max_tokens: int = 64
88
ignore_eos: bool = False
9+
10+
def __post_init__(self):
11+
assert self.temperature > 1e-10, "greedy sampling is not permitted"

0 commit comments

Comments
 (0)