Skip to content

Commit

Permalink
Support for returning past_key_values from the model (#1742)
Browse files Browse the repository at this point in the history
* add support for returning past_key_values from the model

* change order of  keys
  • Loading branch information
idanshen authored Jun 20, 2024
1 parent a57e759 commit ba6abee
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions trl/models/modeling_value_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def forward(
input_ids=None,
past_key_values=None,
attention_mask=None,
return_past_key_values=False,
**kwargs,
):
r"""
Expand All @@ -159,6 +160,7 @@ def forward(
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
return_past_key_values (bool): A flag indicating if the computed hidden-states should be returned.
kwargs (`dict`, `optional`):
Additional keyword arguments, that are passed to the wrapped model.
"""
Expand Down Expand Up @@ -187,7 +189,10 @@ def forward(
if lm_logits.dtype != torch.float32:
lm_logits = lm_logits.float()

return (lm_logits, loss, value)
if return_past_key_values:
return (lm_logits, loss, value, base_model_output.past_key_values)
else:
return (lm_logits, loss, value)

def generate(self, *args, **kwargs):
r"""
Expand Down Expand Up @@ -406,6 +411,7 @@ def forward(
input_ids=None,
past_key_values=None,
attention_mask=None,
return_past_key_values=False,
**kwargs,
):
kwargs["past_key_values"] = past_key_values
Expand All @@ -429,7 +435,10 @@ def forward(
if lm_logits.dtype != torch.float32:
lm_logits = lm_logits.float()

return (lm_logits, loss, value)
if return_past_key_values:
return (lm_logits, loss, value, base_model_output.past_key_values)
else:
return (lm_logits, loss, value)

def generate(self, *args, **kwargs):
r"""
Expand Down

0 comments on commit ba6abee

Please sign in to comment.