diff --git a/trl/models/modeling_value_head.py b/trl/models/modeling_value_head.py
index 76805d8606..bc9f10d904 100644
--- a/trl/models/modeling_value_head.py
+++ b/trl/models/modeling_value_head.py
@@ -144,6 +144,7 @@ def forward(
         input_ids=None,
         past_key_values=None,
         attention_mask=None,
+        return_past_key_values=False,
         **kwargs,
     ):
         r"""
@@ -159,6 +160,7 @@ def forward(
                 Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
                 - 1 for tokens that are **not masked**,
                 - 0 for tokens that are **masked**.
+            return_past_key_values (bool): A flag indicating if the computed past key/values should be returned.
             kwargs (`dict`, `optional`):
                 Additional keyword arguments, that are passed to the wrapped model.
         """
@@ -187,7 +189,10 @@ def forward(
         if lm_logits.dtype != torch.float32:
             lm_logits = lm_logits.float()
 
-        return (lm_logits, loss, value)
+        if return_past_key_values:
+            return (lm_logits, loss, value, base_model_output.past_key_values)
+        else:
+            return (lm_logits, loss, value)
 
     def generate(self, *args, **kwargs):
         r"""
@@ -406,6 +411,7 @@ def forward(
         input_ids=None,
         past_key_values=None,
         attention_mask=None,
+        return_past_key_values=False,
         **kwargs,
     ):
         kwargs["past_key_values"] = past_key_values
@@ -429,7 +435,10 @@ def forward(
         if lm_logits.dtype != torch.float32:
             lm_logits = lm_logits.float()
 
-        return (lm_logits, loss, value)
+        if return_past_key_values:
+            return (lm_logits, loss, value, base_model_output.past_key_values)
+        else:
+            return (lm_logits, loss, value)
 
     def generate(self, *args, **kwargs):
         r"""
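
For reviewers, a minimal usage sketch of the new flag. The `gpt2` checkpoint and the token ids below are illustrative assumptions, not part of this patch; the return shapes follow directly from the diff above:

```python
import torch
from trl import AutoModelForCausalLMWithValueHead

# Illustrative checkpoint; any causal LM supported by the wrapper works.
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
input_ids = torch.tensor([[464, 3290, 318]])  # arbitrary example tokens

# Default behavior is unchanged: the familiar 3-tuple comes back.
lm_logits, loss, value = model(input_ids=input_ids)

# Opting in appends the cache as a fourth element, which a caller can then
# feed back through the existing `past_key_values` argument on the next step.
lm_logits, loss, value, past_key_values = model(
    input_ids=input_ids, return_past_key_values=True
)
next_token = torch.tensor([[50256]])  # hypothetical next token
next_logits, _, _ = model(input_ids=next_token, past_key_values=past_key_values)
```

Keeping the 3-tuple as the default means existing callers that unpack `(lm_logits, loss, value)` are unaffected; only call sites that pass `return_past_key_values=True` see the 4-tuple.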