Question about implementation of top-k sampling (5.3.2 Top-k sampling) #326
-
Hi @rasbt,

There is the following code snippet used in the book as the top-k sampling implementation:

```python
new_logits = torch.where(
    condition=next_token_logits < top_logits[-1],
    input=torch.tensor(float('-inf')),
    other=next_token_logits
)
```

Could you please help with these questions: is there a reason to use `float('-inf')` here rather than `torch.inf`? And wouldn't the following implementation be simpler (probably it is also a little bit faster)?

```python
new_logits = torch.full_like(next_token_logits, -torch.inf)
new_logits[top_pos] = next_token_logits[top_pos]
new_logits
```

Thank you.
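For reference, here is a self-contained sketch running both variants side by side; the tensor values are just illustrative, and `top_logits`/`top_pos` are assumed to come from `torch.topk` as in the book:

```python
import torch

# Illustrative logits for a toy vocabulary; top_k = 3.
next_token_logits = torch.tensor(
    [4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79]
)
top_k = 3
top_logits, top_pos = torch.topk(next_token_logits, top_k)

# Variant from the book: threshold against the k-th largest logit.
new_logits_where = torch.where(
    condition=next_token_logits < top_logits[-1],
    input=torch.tensor(float('-inf')),
    other=next_token_logits,
)

# Proposed alternative: start from -inf and copy back only the top-k entries.
new_logits_fill = torch.full_like(next_token_logits, -torch.inf)
new_logits_fill[top_pos] = next_token_logits[top_pos]

# Without duplicate logit values, both produce the same result.
print(torch.equal(new_logits_where, new_logits_fill))  # True
```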
-
I think there may be no difference between them; you can test this simply in the REPL:

```python
>>> import math
>>> import numpy as np
>>> import torch
>>> torch.inf == math.inf == np.inf == float('inf')
True
>>> -torch.inf == -math.inf == -np.inf == -float('inf')
True
```
-
Good call. I definitely could have used `torch.inf`. It's just muscle memory at this point because I believe it didn't exist in early versions of PyTorch (pre-2.0 or so).

Overall, I like your alternative implementation. Thanks for sharing that! I think the fact that it doesn't allow duplication can be seen as a pro or con. In my implementation, if you have a top-3 setting and duplicates like in

```python
[0.412314, 0.412314, -0.5, 0.1, 0.2, 1.0, 0.8, ...]
```

it will not strictly be top 3 anymore but top 3+, e.g.,

```python
[0.412314, 0.412314, 1.0, 0.8, ...]
```

and the sampling will treat both tokens with equal probability. In your implementation, it will choose one over the other (I think the one that has the lower index, due to how `torch.topk` picks among tied values).
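Here's a minimal sketch, using the illustrative duplicate values above, that makes the difference visible:

```python
import torch

# Duplicate values from the example above; top_k = 3.
logits = torch.tensor([0.412314, 0.412314, -0.5, 0.1, 0.2, 1.0, 0.8])
top_logits, top_pos = torch.topk(logits, 3)

# Threshold-based masking keeps *every* logit >= the 3rd largest,
# so both tied duplicates survive ("top 3+"): 4 candidates remain.
thresholded = torch.where(logits < top_logits[-1],
                          torch.tensor(float('-inf')), logits)

# Index-based masking keeps exactly 3 entries; torch.topk returns
# only one of the two tied positions.
indexed = torch.full_like(logits, -torch.inf)
indexed[top_pos] = logits[top_pos]

print((thresholded > -torch.inf).sum().item())  # 4
print((indexed > -torch.inf).sum().item())      # 3
```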
I wouldn't say one behavior is "more correct" or "better" than the other; it's just based on what you want. In any case, I added a code comment with your alternative to the notebook in case readers want to follow this discussion here. Thanks for sharing this!