From 9dcdcb557861623e496a2cbc539913ae4353d95b Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Fri, 13 Dec 2024 18:45:22 -0800 Subject: [PATCH] [python] Update rolling batch to return only deltas --- .../rolling_batch/lmi_dist_rolling_batch.py | 3 +- .../rolling_batch/rolling_batch_vllm_utils.py | 37 +----- .../rolling_batch/vllm_rolling_batch.py | 2 + .../djl_python/tests/test_rb_vllm_utils.py | 106 ++---------------- 4 files changed, 18 insertions(+), 130 deletions(-) diff --git a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py index d43cf6c9ba..fbf25a3f06 100644 --- a/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/lmi_dist_rolling_batch.py @@ -19,7 +19,7 @@ from lmi_dist.arg_utils import VllmEngineArgs from lmi_dist.init_engine import engine_from_args from lmi_dist.seq2seq_engine import Seq2SeqPreprocessor -from vllm import SamplingParams +from vllm.sampling_params import RequestOutputKind from vllm.utils import AtomicCounter from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params @@ -140,6 +140,7 @@ def translate_lmi_dist_params(self, parameters: dict): :return: The same parameters dict, but with lmi-dist style parameter names. """ + parameters["output_kind"] = RequestOutputKind.DELTA parameters["max_tokens"] = parameters.pop("max_new_tokens", 30) # If `do_sample` is not provided, force temperature=0.0, i.e. greedy # else set to user-provided value or default to 1.0 diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py index bad3cc8eb6..ee7be75b7c 100644 --- a/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py +++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py @@ -91,47 +91,26 @@ def update_request_cache_with_output(request_cache: OrderedDict, def update_multiple_sequences(cache, request_output, vllm_request_output): for completion_output in vllm_request_output.outputs: - sequence_index = completion_output.index - if f"sequence_index_{sequence_index}" not in cache: - cache[f"sequence_index_{sequence_index}"] = { - "curr_length": 0, - "num_generated_tokens": 0 - } if sequence_index not in request_output.sequences: request_output.sequences[sequence_index] = Sequence() - # set token of the sequence - # previous length of token ids generated - prev_len = cache[f"sequence_index_{sequence_index}"][ - 'num_generated_tokens'] - # curr length of the token ids generated so far - cur_len = len(completion_output.token_ids) - cache[f"sequence_index_{sequence_index}"][ - "num_generated_tokens"] = cur_len - # get the newly generated token_ids - new_token_ids = completion_output.token_ids[ - prev_len: - cur_len] if prev_len < cur_len else completion_output.token_ids + new_token_ids = completion_output.token_ids # get the newly generated token texts for speculative decoding output_token_texts = [] if hasattr(completion_output, "output_token_texts"): - output_token_texts = completion_output.output_token_texts[ - prev_len: - cur_len] if prev_len < cur_len else completion_output.output_token_texts + output_token_texts = completion_output.output_token_texts top_tokens = [] token_texts = [] # calculate log probs and token_texts if completion_output.logprobs: - new_logprobs_list = completion_output.logprobs[ - 
prev_len: - cur_len] if prev_len < cur_len else completion_output.logprobs new_logprobs = [] - for token_id, logprobs in zip(new_token_ids, new_logprobs_list): + for token_id, logprobs in zip(new_token_ids, + completion_output.logprobs): new_logprobs.append(logprobs[token_id].logprob) decoded_token = logprobs[token_id].decoded_token if logprobs[ token_id].decoded_token else "" @@ -141,13 +120,10 @@ def update_multiple_sequences(cache, request_output, vllm_request_output): Token(id=token_id_key, text=logprob.decoded_token, log_prob=logprob.logprob)) - elif new_token_ids: # TODO: Test and remove this. logprobs is always set 1. This case should never happen. new_logprobs = [None] * len(new_token_ids) - curr_length = cache[f"sequence_index_{sequence_index}"][ - "curr_length"] - token_texts.append(completion_output.text[curr_length:]) + token_texts.append(completion_output.text) if not output_token_texts: if len(token_texts) != len(new_token_ids): @@ -186,9 +162,6 @@ def update_multiple_sequences(cache, request_output, vllm_request_output): request_output.sequences[sequence_index].set_next_top_tokens( top_tokens) - cache[f"sequence_index_{sequence_index}"]["curr_length"] = len( - completion_output.text) - def get_speculative_decoding_metrics_record( completion_output: CompletionOutput, diff --git a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py index 66abbf811e..71f80258cd 100644 --- a/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py +++ b/engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py @@ -13,6 +13,7 @@ from collections import OrderedDict, defaultdict from vllm import LLMEngine, SamplingParams +from vllm.sampling_params import RequestOutputKind from vllm.utils import random_uuid, AtomicCounter from djl_python.request import Request @@ -78,6 +79,7 @@ def translate_vllm_params(self, parameters: dict) -> dict: :return: The same parameters dict, but with VLLM style parameter names. 
""" + parameters["output_kind"] = RequestOutputKind.DELTA parameters["max_tokens"] = parameters.pop("max_new_tokens", 30) if "seed" in parameters.keys(): parameters["seed"] = int(parameters["seed"]) diff --git a/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py b/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py index 41486fb20c..86e9bcd0c4 100644 --- a/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py +++ b/engines/python/setup/djl_python/tests/test_rb_vllm_utils.py @@ -1,6 +1,5 @@ import sys import unittest -import uuid from dataclasses import dataclass from typing import List, Optional, Dict, Union from collections import OrderedDict @@ -12,7 +11,7 @@ import djl_python from djl_python.output_formatter import _json_output_formatter from djl_python.request import Request -from djl_python.request_io import TextGenerationOutput, TextInput, Sequence, Token, RequestInput +from djl_python.request_io import TextGenerationOutput, TextInput, Sequence, Token '''These Mock classes are in compliance with vllm RequestOutput version 0.5.3.post1''' @@ -148,23 +147,10 @@ def __init__( ], outputs=[ MockCompletionOutput(index=1, - text=' member of', - token_ids=[4292, 302], + text=' of', + token_ids=[302], cumulative_logprob=-4.3041129764169455, logprobs=[{ - 4292: - MockLogprob(logprob=-4.2740092277526855, - rank=4, - decoded_token=' member'), - 2032: - MockLogprob(logprob=-3.0240092277526855, - rank=1, - decoded_token=' big'), - 888: - MockLogprob(logprob=-4.4099884033203125, - rank=3, - decoded_token=' new'), - }, { 302: MockLogprob(logprob=-0.03010374866425991, rank=1, @@ -181,27 +167,10 @@ def __init__( finish_reason=None, stop_reason=None), MockCompletionOutput(index=0, - text=' consolidated', - token_ids=[22968, 601], + text='ated', + token_ids=[601], cumulative_logprob=-13.402491569519043, logprobs=[{ - 22968: - MockLogprob(logprob=-12.117759704589844, - rank=5308, - decoded_token=' consolid'), - 2032: - MockLogprob(logprob=-3.0240092277526855, - rank=1, - decoded_token=' big'), - 17372: - MockLogprob(logprob=-13.409988403320312, - rank=10489, - decoded_token=' crown'), - 888: - MockLogprob(logprob=-4.4099884033203125, - rank=3, - decoded_token=' new'), - }, { 601: MockLogprob(logprob=-1.2847318649291992, rank=2, @@ -235,37 +204,10 @@ def __init__( ], outputs=[ MockCompletionOutput(index=1, - text=' member of the', - token_ids=[4292, 302, - 272], + text=' the', + token_ids=[272], cumulative_logprob=-4.815703457221389, logprobs=[{ - 4292: - MockLogprob(logprob=-4.2740092277526855, - rank=4, - decoded_token=' member'), - 2032: - MockLogprob(logprob=-3.0240092277526855, - rank=1, - decoded_token=' big'), - 888: - MockLogprob(logprob=-4.4099884033203125, - rank=3, - decoded_token=' new'), - }, { - 302: - MockLogprob(logprob=-0.03010374866425991, - rank=1, - decoded_token=' of'), - 235290: - MockLogprob(logprob=-2.2026185989379883, - rank=1, - decoded_token='-'), - 578: - MockLogprob(logprob=-2.2026185989379883, - rank=2, - decoded_token=' and') - }, { 272: MockLogprob(logprob=-0.5115904808044434, rank=1, @@ -282,40 +224,10 @@ def __init__( finish_reason='length', stop_reason=None), MockCompletionOutput(index=0, - text=' consolidated or', - token_ids=[22968, 601, 442], + text=' or', + token_ids=[442], cumulative_logprob=-20.4010648727417, logprobs=[{ - 22968: - MockLogprob(logprob=-12.117759704589844, - rank=5308, - decoded_token=' consolid'), - 2032: - MockLogprob(logprob=-3.0240092277526855, - rank=1, - decoded_token=' big'), - 17372: - 
MockLogprob(logprob=-13.409988403320312, - rank=10489, - decoded_token=' crown'), - 888: - MockLogprob(logprob=-4.4099884033203125, - rank=3, - decoded_token=' new'), - }, { - 601: - MockLogprob(logprob=-1.2847318649291992, - rank=2, - decoded_token='ated'), - 1028: - MockLogprob(logprob=-0.909731924533844, - rank=1, - decoded_token='ator'), - 1162: - MockLogprob(logprob=-0.8929234743118286, - rank=2, - decoded_token=' year') - }, { 442: MockLogprob(logprob=-6.998573303222656, rank=188,
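
Background on the change: with vllm's RequestOutputKind.DELTA set in translate_lmi_dist_params / translate_vllm_params, each engine step is expected to report only the text and token_ids generated since the previous step, which is why the prev_len / curr_length bookkeeping and the per-sequence cache entries could be dropped from update_multiple_sequences, and why the mock outputs in test_rb_vllm_utils.py now contain single-token deltas instead of the full generation so far. A rough sketch of the difference between the two accumulation styles; the StepOutput dataclass and consume_* helpers are illustrative stand-ins (not the real vllm CompletionOutput API), and the token ids are borrowed from the test mocks above:

from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class StepOutput:
    # Stand-in for one per-step completion output: generated text and token ids.
    text: str
    token_ids: List[int]


def consume_cumulative(steps: List[StepOutput]) -> Tuple[str, List[int]]:
    # Old behavior: each step repeats the whole generation so far, so the
    # consumer keeps prev_len / curr_length style bookkeeping and slices off
    # only the part it has not emitted yet.
    text, token_ids = "", []
    for step in steps:
        text += step.text[len(text):]
        token_ids.extend(step.token_ids[len(token_ids):])
    return text, token_ids


def consume_delta(steps: List[StepOutput]) -> Tuple[str, List[int]]:
    # New behavior (RequestOutputKind.DELTA): each step carries only what was
    # generated since the last step, so plain appending is enough.
    text, token_ids = "", []
    for step in steps:
        text += step.text
        token_ids.extend(step.token_ids)
    return text, token_ids


# The same stream expressed both ways yields the same final result.
cumulative = [StepOutput(" member", [4292]), StepOutput(" member of", [4292, 302])]
delta = [StepOutput(" member", [4292]), StepOutput(" of", [302])]
assert consume_cumulative(cumulative) == consume_delta(delta) == (" member of", [4292, 302])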