You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
def generate(self, idx, max_new_tokens):
    """Autoregressively sample new tokens and append them to ``idx``.

    Args:
        idx: (B, T) tensor of token indices forming the current context.
        max_new_tokens: Number of sampling steps to run; each step appends
            one sampled index per row.

    Returns:
        Tensor of shape (B, T + max_new_tokens): the original ``idx`` with
        the sampled indices concatenated along dim 1.
    """
    # The output is a tensor of sampled integer indices, which cannot carry
    # gradients, so run the forward passes without building an autograd graph.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Crop idx to the max size of our positional embeddings table.
            idx_crop = idx[:, -self.context_length:]
            # The model returns (logits, loss); loss is irrelevant here.
            logits, _loss = self(idx_crop)
            # Keep only the last time step: (B, T, C) -> (B, C).
            logits_last_timestep = logits[:, -1, :]
            # Turn the logits into a probability distribution over C.
            probs = F.softmax(logits_last_timestep, dim=-1)
            # Draw one index per row from that distribution: (B, 1).
            idx_next = torch.multinomial(probs, num_samples=1)
            # Extend the running sequence with the sampled indices.
            idx = torch.cat((idx, idx_next), dim=1)
    return idx
The text was updated successfully, but these errors were encountered:
def generate(self, idx, max_new_tokens):
    """Extend ``idx`` by sampling ``max_new_tokens`` tokens from the model.

    Args:
        idx: (B, T) tensor of token indices in the current context.
        max_new_tokens: How many indices to sample and append per row.

    Returns:
        The input ``idx`` with the sampled indices concatenated along
        dim 1, giving shape (B, T + max_new_tokens).
    """
    # Window size fed to the model on each step (max size of the positional
    # embeddings table); read once because it does not change in the loop.
    window = self.context_length
    # Only sampled integer indices are returned, so no autograd graph is
    # needed for the forward passes.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Feed at most the last `window` tokens; the model call yields
            # (logits, loss) and only the logits are used.
            logits, _unused_loss = self(idx[:, -window:])
            # Logits are (B, T, C); the final time step predicts the next token.
            next_token_probs = F.softmax(logits[:, -1, :], dim=-1)
            # Sample one index per row and append it to the sequence.
            sampled = torch.multinomial(next_token_probs, num_samples=1)
            idx = torch.cat((idx, sampled), dim=1)
    return idx
The text was updated successfully, but these errors were encountered: