Remove additional float/clone() for perf #1374

Open · wants to merge 2 commits into base: transformers_future
16 changes: 7 additions & 9 deletions optimum/habana/transformers/generation/utils.py
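At a glance, the PR drops the `.float()` upcasts (and one `.clone()`) that upstream transformers applies to logits before sampling. Below is a minimal sketch of the pattern being removed, with illustrative shapes and dtype (not taken from the PR itself):

```python
import torch

# Illustrative shapes/dtype; the real tensors come from the model forward pass.
batch_size, vocab_size = 8, 32000
logits = torch.randn(batch_size, 1, vocab_size, dtype=torch.bfloat16)

# Upstream pattern: upcast for precision. This allocates a new float32
# [batch_size, vocab_size] tensor on every decoding step.
next_token_logits_fp32 = logits[:, -1, :].float()

# Pattern after this PR: keep the model's dtype, no extra copy.
next_token_logits = logits[:, -1, :]
assert next_token_logits.dtype == torch.bfloat16
```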
@@ -705,7 +705,6 @@ def _prepare_cache_for_generation(
Changes:
- change the default from DynamicCache to tuples
"""

cache_name = "past_key_values" if "mamba" not in self.__class__.__name__.lower() else "cache_params"
requires_cross_attention_cache = (
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
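The docstring change above refers to the two KV-cache layouts in transformers. A brief sketch of the difference, assuming the `transformers.cache_utils.DynamicCache` API and illustrative tensor shapes:

```python
import torch
from transformers.cache_utils import DynamicCache

num_layers, batch, heads, seq_len, head_dim = 2, 1, 8, 16, 64

# Legacy/tuple format: one (key, value) pair of tensors per layer.
legacy_cache = tuple(
    (torch.zeros(batch, heads, seq_len, head_dim),
     torch.zeros(batch, heads, seq_len, head_dim))
    for _ in range(num_layers)
)

# DynamicCache wraps the same data in an object; transformers provides
# converters between the two representations.
cache = DynamicCache.from_legacy_cache(legacy_cache)
round_tripped = cache.to_legacy_cache()
```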
@@ -1801,7 +1800,7 @@ def _contrastive_search(
logit_for_next_step = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2)
else:
# .float() is needed to retain precision for later logits manipulations
-logit_for_next_step = outputs.logits[:, -1, :].float()
+logit_for_next_step = outputs.logits[:, -1, :]

model_kwargs = self._update_model_kwargs_for_generation(
outputs,
@@ -1968,7 +1967,7 @@ def _contrastive_search(
full_hidden_states = outputs.hidden_states

# .float() is needed to retain precision for later logits manipulations
-logits = outputs.logits[:, -1, :].float()
+logits = outputs.logits[:, -1, :]
context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)

# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
@@ -2355,7 +2354,7 @@ def _sample(
if token_idx is not None and outputs.logits.shape[-2] > 1:
# case1 (w/o KV caching): outputs.logits.shape: [batch_size, max_length, vocab_size]
if self.config.is_encoder_decoder:
-next_token_logits = outputs.logits[:, token_idx - 1, :].float()
+next_token_logits = outputs.logits[:, token_idx - 1, :]
next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits)
else:
if model_kwargs.get("num_virtual_tokens", 0) > 0:
@@ -2370,7 +2369,7 @@ def _sample(
next_token_scores = logits_processor(input_ids, next_token_logits)
else:
# .float() is needed to retain precision for later logits manipulations
-next_token_logits = outputs.logits[:, -1, :].float()
+next_token_logits = outputs.logits[:, -1, :]

Review comment (Collaborator): consider keeping this
if token_idx is not None and self.config.is_encoder_decoder:
# case2 (with KV caching): outputs.logits.shape: [batch_size, 1, vocab_size]
next_token_scores = logits_processor(input_ids[:, :token_idx], next_token_logits)
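The review comment above suggests keeping the upcast here. The precision concern it points at can be reproduced in isolation; a small sketch (with an illustrative vocab size) comparing log-softmax in bfloat16 versus float32:

```python
import torch

torch.manual_seed(0)
x = torch.randn(1, 32000)  # illustrative vocab size

lp_fp32 = torch.log_softmax(x, dim=-1)
lp_bf16 = torch.log_softmax(x.to(torch.bfloat16), dim=-1).float()

# bfloat16 keeps only an 8-bit mantissa, so per-token log-probs can
# differ noticeably from the float32 reference.
print((lp_fp32 - lp_bf16).abs().max())
```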
@@ -2814,7 +2813,7 @@ def expand_if_needed(tensor, new_size, value, dim=-1):
else:
next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2)
else:
-next_token_logits = outputs.logits[:, -1, :].float()
+next_token_logits = outputs.logits[:, -1, :]

Review comment (Collaborator): we don't normally run to here

next_token_scores = torch.nn.functional.log_softmax(
next_token_logits, dim=-1
@@ -3260,7 +3259,7 @@ def _constrained_beam_search(
else:
next_token_logits = torch.index_select(outputs.logits, -2, token_idx - 1).squeeze(-2)
else:
-next_token_logits = outputs.logits[:, -1, :].float()
+next_token_logits = outputs.logits[:, -1, :]

next_token_scores = torch.nn.functional.log_softmax(
next_token_logits, dim=-1
@@ -3580,8 +3579,7 @@ def _assisted_decoding(

# 2.3. Process the new logits
# .float() is needed to retain precision for later logits manipulations
Review comment (Contributor): well, I think you should at least remove this comment

-new_logits = outputs.logits[:, -candidate_length - 1 :].float()  # excludes the input prompt if present
-next_token_logits = new_logits.clone()
+new_logits = outputs.logits[:, -candidate_length - 1 :]
if len(logits_processor) > 0:
for i in range(candidate_length + 1):
new_logits[:, i, :] = logits_processor(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])
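One behavioral note on this last hunk: `outputs.logits[:, -candidate_length - 1 :]` is a basic-indexing slice, so with the `.clone()` removed, the in-place writes in the loop above go through to `outputs.logits` itself. A standalone sketch of that aliasing, using illustrative tensors rather than real model outputs:

```python
import torch

outputs_logits = torch.randn(2, 5, 10)  # [batch, seq, vocab], illustrative
candidate_length = 3

new_logits = outputs_logits[:, -candidate_length - 1 :]  # a view, not a copy
new_logits[:, 0, :] = 0.0  # in-place edit, like the logits_processor loop

# The edit is visible through the original tensor:
assert outputs_logits[:, -candidate_length - 1, :].abs().sum() == 0
```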