Commit

[Neuron][WIP] add log probs calculation in Neuron (#1516)
Qing Lan authored Jan 30, 2024
1 parent 2c9f08c commit f3f4cfd
Showing 3 changed files with 14 additions and 8 deletions.
@@ -170,7 +170,7 @@ def append(self, next_token: int, next_token_text: str):
         self.increment_cache_id()

     def select(self, input_ids: torch.LongTensor,
-               logits: torch.Tensor) -> torch.LongTensor:
+               logits: torch.Tensor) -> (torch.LongTensor, torch.Tensor):
         """Select the next token from the candidate logits.

         Args:
@@ -182,7 +182,8 @@ def select(self, input_ids: torch.LongTensor,
         Return:
             `torch.LongTensor`: A scalar `torch.LongTensor` containing the selected token.
         """
-        return self._selector.select(input_ids, logits)[0]
+        next_ids, next_log_probs = self._selector.select(input_ids, logits)
+        return next_ids[0], next_log_probs

     def increment_cache_id(self):
         self._cache_id += 1
@@ -374,7 +375,8 @@ def _generate_token(
             request_ids.append(request_id)
             next_token_logits = outputs.logits[i:i + 1, -1, :]
             slot_input_ids = input_ids[i:i + 1, :]
-            next_token = slot.select(slot_input_ids, next_token_logits)
+            next_token, next_log_prob = slot.select(slot_input_ids,
+                                                    next_token_logits)
             next_token_text = slot.decoder.decode(next_token.item())
             slot.trim_cache_id()
             slot.append(next_token, next_token_text)
@@ -400,7 +402,7 @@ def _generate_token(
                 request_id=request_id,
                 prefill_tokens=None,
                 token_id=next_token,
-                token_logprob=None,
+                token_logprob=next_log_prob,
                 token_text=next_token_text,
                 token_is_special=(next_token in [self.special_tokens]),
                 generated_text=generated_text,
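Taken together, the two hunks above mean each streamed token record now carries a real log probability instead of `None`. A minimal consumer-side sketch of what that enables, using a hypothetical `Token` dataclass as a stand-in for the server's actual response type:

```python
from dataclasses import dataclass


@dataclass
class Token:
    """Hypothetical stand-in for the server's per-token response record."""
    token_id: int
    token_text: str
    token_logprob: float  # populated by this commit; previously None


# With per-token log probs, a client can accumulate the sequence-level
# log likelihood as tokens stream back (the values below are made up).
tokens = [Token(464, " The", -0.12), Token(3290, " dog", -1.07)]
sequence_logprob = sum(t.token_logprob for t in tokens)
print(f"sequence log likelihood: {sequence_logprob:.2f}")
```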
@@ -175,7 +175,7 @@ def create(cls, input_ids: torch.Tensor,
         )

     def select(self, input_ids: torch.LongTensor,
-               logits: torch.Tensor) -> torch.LongTensor:
+               logits: torch.Tensor) -> (torch.LongTensor, torch.Tensor):
         """Select the next tokens from the candidate logits.

         Args:
@@ -188,10 +188,14 @@ def select(self, input_ids: torch.LongTensor,
             `torch.LongTensor`: A `torch.LongTensor` containing the selected tokens.
         """
         scores = self.logits_processor(input_ids, logits)
+        logprobs = torch.log_softmax(scores, -1)
         if self.mode == GenerationMode.SAMPLE:
-            return self._sample(scores)
+            next_ids = self._sample(scores)
         else:
-            return torch.argmax(scores, dim=-1)
+            next_ids = torch.argmax(scores, dim=-1)
+        next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1,
+                                                                1)).view(-1)
+        return next_ids, next_logprobs

     def _sample(self, scores: torch.Tensor) -> torch.LongTensor:
         if self.fast_topk:
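The hunk above is the core of the change: take a log-softmax over the processed scores once, then `gather` the log probability of whichever id was chosen, whether it came from sampling or from `argmax`. A self-contained sketch of that technique, with illustrative shapes rather than the project's actual tensors:

```python
import torch

batch_size, vocab_size = 2, 8
scores = torch.randn(batch_size, vocab_size)  # stand-in for processed logits

# Normalize once, then look up the log prob of each selected id.
logprobs = torch.log_softmax(scores, dim=-1)   # (batch, vocab)
next_ids = torch.argmax(scores, dim=-1)        # (batch,), greedy pick
next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1)

assert next_logprobs.shape == (batch_size,)
# next_logprobs[i] is log P(next_ids[i] | context) under the processed scores.
```

Hoisting `logprobs` above the `if` lets the greedy and sampling branches share a single normalization, which is why the diff computes it before the mode check.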
2 changes: 1 addition & 1 deletion tests/integration/llm/README.md
@@ -12,7 +12,7 @@ local_model_path = Path("./model")
 local_model_path.mkdir(exist_ok=True)
 model_name = "facebook/opt-30b"
 # Only download pytorch checkpoint files
-allow_patterns = ["*.json", "*.pt", "*.bin", "*.txt", "*.model"]
+allow_patterns = ["*.json", "*.pt", "*.bin", "*.txt", "*.model", "*.tiktoken"]

 # - Leverage the snapshot library to download the model since the model is stored in the repository using LFS
 snapshot_download(
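For context, the README snippet this hunk edits filters which checkpoint files to fetch; adding `"*.tiktoken"` keeps models that ship tiktoken vocabulary files complete. A hedged reconstruction of the surrounding snippet, assuming the `huggingface_hub` package (parameter names below are one plausible spelling, not necessarily the README's exact call):

```python
from pathlib import Path

from huggingface_hub import snapshot_download

local_model_path = Path("./model")
local_model_path.mkdir(exist_ok=True)
model_name = "facebook/opt-30b"
# Only download pytorch checkpoint files (plus tiktoken vocabulary files).
allow_patterns = ["*.json", "*.pt", "*.bin", "*.txt", "*.model", "*.tiktoken"]

# Use snapshot_download since the model is stored in the repository with LFS.
snapshot_download(
    repo_id=model_name,
    local_dir=local_model_path,
    allow_patterns=allow_patterns,
)
```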
