Hotfixing qwen2 and starcoder2 (which also get clamping). (huggingfac…
Narsil authored and yuanwu2017 committed Sep 24, 2024
1 parent bc5a792 commit d580215
Showing 2 changed files with 2 additions and 2 deletions.
@@ -368,7 +368,7 @@ def forward(
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
@@ -534,7 +534,7 @@ def forward(
         elif self.max_past is not None:
             # Clamp in decode mode as paged attention requires clamped values whereas the flash attention
             # kernel requires the true values
-            input_lengths = torch.clamp(input_lengths, max=self.max_past_tensor)
+            input_lengths = input_lengths.clamp(max=self.max_past_tensor)
 
         hidden_states = self.model(
             input_ids,
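For reference, both hunks replace the functional call torch.clamp(input_lengths, max=self.max_past_tensor) with the equivalent method form input_lengths.clamp(max=...). Below is a minimal, hypothetical sketch of the operation the decode path relies on: clamping a batch of per-sequence lengths against a tensor upper bound. The names input_lengths and max_past (standing in for self.max_past_tensor) and the values are illustrative, not taken from the commit.

import torch

# Hypothetical stand-ins for the model's values: per-sequence lengths in a
# decode batch, and the sliding-window size (self.max_past_tensor in the
# model) as a 0-dim tensor.
input_lengths = torch.tensor([120, 4096, 9000])
max_past = torch.tensor(4096)

# Method form used by the hotfix: caps each length at the window size so the
# paged-attention kernel only ever sees lengths within the cached window,
# while the flash-attention prefill path keeps the true lengths.
clamped = input_lengths.clamp(max=max_past)
print(clamped)  # tensor([ 120, 4096, 4096])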
