File tree Expand file tree Collapse file tree 2 files changed +8
-6
lines changed Expand file tree Collapse file tree 2 files changed +8
-6
lines changed Original file line number Diff line number Diff line change @@ -7,10 +7,9 @@ class Sampler(nn.Module):
77 def __init__ (self ):
88 super ().__init__ ()
99
10+ @torch .compile
1011 def forward (self , logits : torch .Tensor , temperatures : torch .Tensor ):
11- logits = logits .float ()
12- greedy_tokens = logits .argmax (dim = - 1 )
13- logits .div_ (temperatures .unsqueeze (dim = 1 ))
14- probs = torch .softmax (logits , dim = - 1 , dtype = torch .float )
15- sample_tokens = probs .div_ (torch .empty_like (probs ).exponential_ (1 ) + 1e-10 ).argmax (dim = - 1 )
16- return torch .where (temperatures == 0 , greedy_tokens , sample_tokens )
12+ logits = logits .float ().div_ (temperatures .unsqueeze (dim = 1 ))
13+ probs = torch .softmax (logits , dim = - 1 )
14+ sample_tokens = probs .div_ (torch .empty_like (probs ).exponential_ (1 ).clamp_min_ (1e-10 )).argmax (dim = - 1 )
15+ return sample_tokens
Original file line number Diff line number Diff line change @@ -6,3 +6,6 @@ class SamplingParams:
66 temperature : float = 1.0
77 max_tokens : int = 64
88 ignore_eos : bool = False
9+
10+ def __post_init__ (self ):
11+ assert self .temperature > 1e-10 , "greedy sampling is not permitted"
You can’t perform that action at this time.
0 commit comments