Skip to content

Commit

Permalink
[python] Update rolling batch to return only deltas
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 committed Dec 14, 2024
1 parent 96a0efd commit 9dcdcb5
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from lmi_dist.arg_utils import VllmEngineArgs
from lmi_dist.init_engine import engine_from_args
from lmi_dist.seq2seq_engine import Seq2SeqPreprocessor
from vllm import SamplingParams
from vllm.sampling_params import RequestOutputKind
from vllm.utils import AtomicCounter

from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params
Expand Down Expand Up @@ -140,6 +140,7 @@ def translate_lmi_dist_params(self, parameters: dict):
:return: The same parameters dict, but with lmi-dist style parameter names.
"""
parameters["output_kind"] = RequestOutputKind.DELTA
parameters["max_tokens"] = parameters.pop("max_new_tokens", 30)
# If `do_sample` is not provided, force temperature=0.0, i.e. greedy
# else set to user-provided value or default to 1.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,47 +91,26 @@ def update_request_cache_with_output(request_cache: OrderedDict,

def update_multiple_sequences(cache, request_output, vllm_request_output):
for completion_output in vllm_request_output.outputs:

sequence_index = completion_output.index
if f"sequence_index_{sequence_index}" not in cache:
cache[f"sequence_index_{sequence_index}"] = {
"curr_length": 0,
"num_generated_tokens": 0
}

if sequence_index not in request_output.sequences:
request_output.sequences[sequence_index] = Sequence()

# set token of the sequence
# previous length of token ids generated
prev_len = cache[f"sequence_index_{sequence_index}"][
'num_generated_tokens']
# curr length of the token ids generated so far
cur_len = len(completion_output.token_ids)
cache[f"sequence_index_{sequence_index}"][
"num_generated_tokens"] = cur_len

# get the newly generated token_ids
new_token_ids = completion_output.token_ids[
prev_len:
cur_len] if prev_len < cur_len else completion_output.token_ids
new_token_ids = completion_output.token_ids

# get the newly generated token texts for speculative decoding
output_token_texts = []
if hasattr(completion_output, "output_token_texts"):
output_token_texts = completion_output.output_token_texts[
prev_len:
cur_len] if prev_len < cur_len else completion_output.output_token_texts
output_token_texts = completion_output.output_token_texts

top_tokens = []
token_texts = []
# calculate log probs and token_texts
if completion_output.logprobs:
new_logprobs_list = completion_output.logprobs[
prev_len:
cur_len] if prev_len < cur_len else completion_output.logprobs
new_logprobs = []
for token_id, logprobs in zip(new_token_ids, new_logprobs_list):
for token_id, logprobs in zip(new_token_ids,
completion_output.logprobs):
new_logprobs.append(logprobs[token_id].logprob)
decoded_token = logprobs[token_id].decoded_token if logprobs[
token_id].decoded_token else ""
Expand All @@ -141,13 +120,10 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
Token(id=token_id_key,
text=logprob.decoded_token,
log_prob=logprob.logprob))

elif new_token_ids:
# TODO: Test and remove this. logprobs is always set 1. This case should never happen.
new_logprobs = [None] * len(new_token_ids)
curr_length = cache[f"sequence_index_{sequence_index}"][
"curr_length"]
token_texts.append(completion_output.text[curr_length:])
token_texts.append(completion_output.text)

if not output_token_texts:
if len(token_texts) != len(new_token_ids):
Expand Down Expand Up @@ -186,9 +162,6 @@ def update_multiple_sequences(cache, request_output, vllm_request_output):
request_output.sequences[sequence_index].set_next_top_tokens(
top_tokens)

cache[f"sequence_index_{sequence_index}"]["curr_length"] = len(
completion_output.text)


def get_speculative_decoding_metrics_record(
completion_output: CompletionOutput,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from collections import OrderedDict, defaultdict

from vllm import LLMEngine, SamplingParams
from vllm.sampling_params import RequestOutputKind
from vllm.utils import random_uuid, AtomicCounter

from djl_python.request import Request
Expand Down Expand Up @@ -78,6 +79,7 @@ def translate_vllm_params(self, parameters: dict) -> dict:
:return: The same parameters dict, but with VLLM style parameter names.
"""
parameters["output_kind"] = RequestOutputKind.DELTA
parameters["max_tokens"] = parameters.pop("max_new_tokens", 30)
if "seed" in parameters.keys():
parameters["seed"] = int(parameters["seed"])
Expand Down
106 changes: 9 additions & 97 deletions engines/python/setup/djl_python/tests/test_rb_vllm_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import sys
import unittest
import uuid
from dataclasses import dataclass
from typing import List, Optional, Dict, Union
from collections import OrderedDict
Expand All @@ -12,7 +11,7 @@
import djl_python
from djl_python.output_formatter import _json_output_formatter
from djl_python.request import Request
from djl_python.request_io import TextGenerationOutput, TextInput, Sequence, Token, RequestInput
from djl_python.request_io import TextGenerationOutput, TextInput, Sequence, Token
'''These Mock classes are in compliance with vllm RequestOutput version 0.5.3.post1'''


Expand Down Expand Up @@ -148,23 +147,10 @@ def __init__(
],
outputs=[
MockCompletionOutput(index=1,
text=' member of',
token_ids=[4292, 302],
text=' of',
token_ids=[302],
cumulative_logprob=-4.3041129764169455,
logprobs=[{
4292:
MockLogprob(logprob=-4.2740092277526855,
rank=4,
decoded_token=' member'),
2032:
MockLogprob(logprob=-3.0240092277526855,
rank=1,
decoded_token=' big'),
888:
MockLogprob(logprob=-4.4099884033203125,
rank=3,
decoded_token=' new'),
}, {
302:
MockLogprob(logprob=-0.03010374866425991,
rank=1,
Expand All @@ -181,27 +167,10 @@ def __init__(
finish_reason=None,
stop_reason=None),
MockCompletionOutput(index=0,
text=' consolidated',
token_ids=[22968, 601],
text='ated',
token_ids=[601],
cumulative_logprob=-13.402491569519043,
logprobs=[{
22968:
MockLogprob(logprob=-12.117759704589844,
rank=5308,
decoded_token=' consolid'),
2032:
MockLogprob(logprob=-3.0240092277526855,
rank=1,
decoded_token=' big'),
17372:
MockLogprob(logprob=-13.409988403320312,
rank=10489,
decoded_token=' crown'),
888:
MockLogprob(logprob=-4.4099884033203125,
rank=3,
decoded_token=' new'),
}, {
601:
MockLogprob(logprob=-1.2847318649291992,
rank=2,
Expand Down Expand Up @@ -235,37 +204,10 @@ def __init__(
],
outputs=[
MockCompletionOutput(index=1,
text=' member of the',
token_ids=[4292, 302,
272],
text=' the',
token_ids=[272],
cumulative_logprob=-4.815703457221389,
logprobs=[{
4292:
MockLogprob(logprob=-4.2740092277526855,
rank=4,
decoded_token=' member'),
2032:
MockLogprob(logprob=-3.0240092277526855,
rank=1,
decoded_token=' big'),
888:
MockLogprob(logprob=-4.4099884033203125,
rank=3,
decoded_token=' new'),
}, {
302:
MockLogprob(logprob=-0.03010374866425991,
rank=1,
decoded_token=' of'),
235290:
MockLogprob(logprob=-2.2026185989379883,
rank=1,
decoded_token='-'),
578:
MockLogprob(logprob=-2.2026185989379883,
rank=2,
decoded_token=' and')
}, {
272:
MockLogprob(logprob=-0.5115904808044434,
rank=1,
Expand All @@ -282,40 +224,10 @@ def __init__(
finish_reason='length',
stop_reason=None),
MockCompletionOutput(index=0,
text=' consolidated or',
token_ids=[22968, 601, 442],
text=' or',
token_ids=[442],
cumulative_logprob=-20.4010648727417,
logprobs=[{
22968:
MockLogprob(logprob=-12.117759704589844,
rank=5308,
decoded_token=' consolid'),
2032:
MockLogprob(logprob=-3.0240092277526855,
rank=1,
decoded_token=' big'),
17372:
MockLogprob(logprob=-13.409988403320312,
rank=10489,
decoded_token=' crown'),
888:
MockLogprob(logprob=-4.4099884033203125,
rank=3,
decoded_token=' new'),
}, {
601:
MockLogprob(logprob=-1.2847318649291992,
rank=2,
decoded_token='ated'),
1028:
MockLogprob(logprob=-0.909731924533844,
rank=1,
decoded_token='ator'),
1162:
MockLogprob(logprob=-0.8929234743118286,
rank=2,
decoded_token=' year')
}, {
442:
MockLogprob(logprob=-6.998573303222656,
rank=188,
Expand Down

0 comments on commit 9dcdcb5

Please sign in to comment.