Skip to content

Commit

Permalink
remove unused token_idx parameter from needs_tensor_output
Browse files Browse the repository at this point in the history
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
  • Loading branch information
sywangyi committed Oct 11, 2024
1 parent c6cdce6 commit a282dbe
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions optimum/habana/transformers/generation/stopping_criteria.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,15 @@ def gaudi_EosTokenCriteria_call(
return torch.all(is_done).item()


def needs_tensor_output(token_idx, ignore_eos, eos_token_id) -> bool:
def needs_tensor_output(ignore_eos, eos_token_id) -> bool:
    """Return True when stopping criteria must report per-sequence tensor output.

    Tensor output is only needed when EOS tokens are being honored
    (``ignore_eos`` is falsy) and an ``eos_token_id`` has actually been set.
    """
    if ignore_eos:
        return False
    return eos_token_id is not None


def gaudi_StoppingCriteriaList_call(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
) -> Union[torch.BoolTensor, bool]:
kwargs["needs_tensor_output"] = needs_tensor_output(
kwargs.get("token_idx", None), kwargs.get("ignore_eos", True), kwargs.get("eos_token_id", None)
kwargs.get("ignore_eos", True), kwargs.get("eos_token_id", None)
)
is_done = (
torch.full((input_ids.shape[0],), 0, device=input_ids.device, dtype=torch.int8)
Expand Down

0 comments on commit a282dbe

Please sign in to comment.