Skip to content

Commit

Permalink
[Fix] Joining the outputs when running without KV Cache (#1161)
Browse files: browse the repository at this point in the history
  • Loading branch information
dbogunowicz authored Aug 8, 2023
1 parent c8f6b0e commit a0b6088
Showing 1 changed file with 27 additions and 20 deletions.
47 changes: 27 additions & 20 deletions src/deepsparse/transformers/pipelines/text_generation.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -603,30 +603,37 @@ def join_engine_outputs(
:return: A list of joined outputs
"""
tokens, logits = zip(*batch_outputs)
if self.has_cache:
# if the model has kv cache, we need to account for
# the fact that the predicted outputs may have
# different lengths

# find the longest sequence in the batch of tokens
max_len = max([token.shape[1] for token in tokens])

# pad all tokens to the same length
tokens = [
pad_to_fixed_length(
array=prediction,
max_len=max_len,
value=self.tokenizer.pad_token_id,
axis=1,
)
for prediction in tokens
]

# find the longest sequence in the batch of tokens
max_len = max([token.shape[1] for token in tokens])
# find the longest sequence in the batch of logits
max_len = max([logits.shape[1] for logits in logits])

# pad all tokens to the same length
tokens = [
pad_to_fixed_length(
array=prediction,
max_len=max_len,
value=self.tokenizer.pad_token_id,
axis=1,
)
for prediction in tokens
]
tokens = numpy.concatenate(tokens, axis=0)
# pad all logits to the same length
logits = [
pad_to_fixed_length(array=single_logits, max_len=max_len, axis=1)
for single_logits in logits
]

# find the longest sequence in the batch of logits
max_len = max([logits.shape[1] for logits in logits])
# pad all logits to the same length
logits = [
pad_to_fixed_length(array=single_logits, max_len=max_len, axis=1)
for single_logits in logits
]
tokens = numpy.concatenate(tokens, axis=0)
logits = numpy.concatenate(logits, axis=0)

return [tokens, logits]

def _reset_engines_cache(self):
Expand Down

0 comments on commit a0b6088

Please sign in to comment.