Skip to content

Commit aa7577c

Browse files
author
sanchit-gandhi
committed
finalise
1 parent 8507ee0 commit aa7577c

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

src/transformers/models/whisper/generation_whisper.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -956,17 +956,21 @@ def split_by_batch_index(values, key, batch_idx, is_shortform):
956956
if not is_shortform:
957957
# we don't save `past_key_values` as this is too costly for longform
958958
return None
959+
elif isinstance(values, EncoderDecoderCache):
960+
all_past_key_values = []
961+
for layer_idx in range(self.config.decoder_layers):
962+
layer_past_key_values = []
963+
for cache_cls in [values.self_attention_cache, values.cross_attention_cache]:
964+
for v in [cache_cls.key_cache, cache_cls.value_cache]:
965+
layer_past_key_values.append(v[layer_idx][None].cpu())
966+
all_past_key_values.append(tuple(layer_past_key_values))
967+
return tuple(all_past_key_values)
959968
else:
960969
return tuple(tuple(w[batch_idx][None].cpu() for w in values[v]) for v in range(len(values)))
961970

962971
return values[batch_idx].cpu()
963972

964973
sequence_tokens = seek_outputs["sequences"]
965-
966-
if hasattr(seek_outputs, "past_key_values") and seek_outputs.past_key_values is not None:
967-
if isinstance(seek_outputs["past_key_values"], EncoderDecoderCache):
968-
seek_outputs.past_key_values = seek_outputs.past_key_values.to_legacy_cache()
969-
970974
seek_outputs = [
971975
{k: split_by_batch_index(v, k, i, is_shortform) for k, v in seek_outputs.items()}
972976
for i in range(sequence_tokens.shape[0])

0 commit comments

Comments
 (0)