Skip to content

Commit

Permalink
format Python (#1292)
Browse files Browse the repository at this point in the history
  • Loading branch information
rohithkrn authored Nov 8, 2023
1 parent 11a55c3 commit af7f369
Showing 1 changed file with 9 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -285,19 +285,23 @@ def decode(self) -> List[Generation]:
if input_ids is None:
# Create blank inputs covering all slots (even empty ones)
input_ids = torch.full(
[self.model.batch_size, 1], fill_value=self.tokenizer.eos_token_id, dtype=torch.int64
)
[self.model.batch_size, 1],
fill_value=self.tokenizer.eos_token_id,
dtype=torch.int64)
# input_ids are simply the tokens generated by the last decode or prefill requests (other tokens are cached)
input_ids[i, 0] = slot.next_token
if attention_mask is None:
# Create default mask covering all slots (even empty ones)
attention_mask = torch.zeros(
[self.model.batch_size, slot.attention_mask.size(-1)], dtype=torch.int64
)
[self.model.batch_size,
slot.attention_mask.size(-1)],
dtype=torch.int64)
attention_mask[:, -1] = 1
attention_mask[i, :] = slot.attention_mask
if input_ids is None:
raise ValueError("Unable to decode tokens for non-prefilled batches (probably due to a previous failure)")
raise ValueError(
"Unable to decode tokens for non-prefilled batches (probably due to a previous failure)"
)
return self._generate_token(input_ids, attention_mask)

def _generate_token(
Expand Down

0 comments on commit af7f369

Please sign in to comment.