diff --git a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
index 51f0908d9..651d375d6 100644
--- a/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/rolling_batch.py
@@ -55,7 +55,7 @@ def as_dict(self):
 
 
 def _json_output_formatter(token: Token, first_token: bool, last_token: bool,
-                           details: dict):
+                           details: dict, generated_tokens: str):
     """
     json output formatter
 
@@ -74,7 +74,8 @@ def _json_output_formatter(token: Token, first_token: bool, last_token: bool,
 
 
 def _jsonlines_output_formatter(token: Token, first_token: bool,
-                                last_token: bool, details: dict):
+                                last_token: bool, details: dict,
+                                generated_tokens: str):
     """
     jsonlines output formatter
 
@@ -82,10 +83,12 @@ def _jsonlines_output_formatter(token: Token, first_token: bool,
     """
     token_dict = token.as_dict()
     final_dict = {"token": token_dict}
-    if last_token and details:
-        final_dict["details"] = {
-            "finish_reason": details.get("finish_reason", None)
-        }
+    if last_token:
+        final_dict["generated_text"] = generated_tokens
+        if details:
+            final_dict["details"] = {
+                "finish_reason": details.get("finish_reason", None)
+            }
     json_encoded_str = json.dumps(final_dict, ensure_ascii=False) + "\n"
     return json_encoded_str
 
@@ -115,6 +118,7 @@ def __init__(self, id: int, input_text: str, parameters: dict):
         self.first_token = True
         self.last_token = False
         self.token_cache = None
+        self.generated_tokens = []
         if parameters.pop("details", False):
             self.token_cache = []
 
@@ -141,16 +145,21 @@ def set_next_token(self,
             next_token = Token(-1, next_token)
         if self.token_cache is not None:
             self.token_cache.append(next_token.as_dict())
+        self.generated_tokens.append(next_token.text)
         details = {}
-        if last_token and self.token_cache is not None:
-            details["finish_reason"] = finish_reason
-            details["tokens"] = self.token_cache
+        generated_text = None
+        if last_token:
+            generated_text = ''.join(self.generated_tokens)
+            if self.token_cache is not None:
+                details["finish_reason"] = finish_reason
+                details["tokens"] = self.token_cache
         if output_formatter is None:
             self.next_token_str = next_token.text
         else:  # output only supports size one now
             self.next_token_str = output_formatter(next_token,
                                                    self.first_token,
-                                                   last_token, details)
+                                                   last_token, details,
+                                                   generated_text)
         self.last_token = last_token
         self.first_token = False
 
diff --git a/engines/python/setup/djl_python/tests/test_rolling_batch.py b/engines/python/setup/djl_python/tests/test_rolling_batch.py
new file mode 100644
index 000000000..76bad00a5
--- /dev/null
+++ b/engines/python/setup/djl_python/tests/test_rolling_batch.py
@@ -0,0 +1,118 @@
+import json
+import unittest
+from djl_python.rolling_batch.rolling_batch import Request, Token, _json_output_formatter, _jsonlines_output_formatter
+
+
+class TestRollingBatch(unittest.TestCase):
+
+    def test_json_fmt(self):
+        req = Request(0, "This is a wonderful day", {"max_new_tokens": 256})
+        req.set_next_token(Token(244, "He", -0.334532), _json_output_formatter)
+        print(req.get_next_token(), end='')
+        assert req.get_next_token() == '{"generated_text": "He'
+        req.set_next_token(Token(576, "llo", -0.123123),
+                           _json_output_formatter)
+        print(req.get_next_token(), end='')
+        assert req.get_next_token() == 'llo'
+        req.set_next_token(Token(4558, " world", -0.567854),
+                           _json_output_formatter, True, 'length')
+        print(req.get_next_token(), end='')
+        assert req.get_next_token() == ' world"}'
+
+    def test_jsonlines_fmt(self):
+        req = Request(0, "This is a wonderful day", {"max_new_tokens": 256})
+        req.set_next_token(Token(244, "He", -0.334532),
+                           _jsonlines_output_formatter)
+        print(req.get_next_token(), end='')
+        assert json.loads(req.get_next_token()) == {
+            "token": {
+                "id": 244,
+                "text": "He",
+                "log_prob": -0.334532
+            }
+        }
+        req.set_next_token(Token(576, "llo", -0.123123),
+                           _jsonlines_output_formatter)
+        print(req.get_next_token(), end='')
+        assert json.loads(req.get_next_token()) == {
+            "token": {
+                "id": 576,
+                "text": "llo",
+                "log_prob": -0.123123
+            }
+        }
+        req.set_next_token(Token(4558, " world", -0.567854),
+                           _jsonlines_output_formatter, True, 'length')
+        print(req.get_next_token(), end='')
+        assert json.loads(req.get_next_token()) == {
+            "token": {
+                "id": 4558,
+                "text": " world",
+                "log_prob": -0.567854
+            },
+            "generated_text": "Hello world"
+        }
+
+    def test_details(self):
+        req = Request(0, "This is a wonderful day", {
+            "max_new_tokens": 256,
+            "details": True
+        })
+        final_str = []
+        req.set_next_token(Token(244, "He", -0.334532), _json_output_formatter)
+        final_str.append(req.get_next_token())
+        req.set_next_token(Token(576, "llo", -0.123123),
+                           _json_output_formatter)
+        final_str.append(req.get_next_token())
+        req.set_next_token(Token(4558, " world", -0.567854),
+                           _json_output_formatter, True, 'length')
+        final_str.append(req.get_next_token())
+        final_json = json.loads(''.join(final_str))
+        print(final_json)
+        assert final_json == {
+            "generated_text": "Hello world",
+            "details": {
+                "finish_reason":
+                "length",
+                "tokens": [{
+                    "id": 244,
+                    "text": "He",
+                    "log_prob": -0.334532
+                }, {
+                    "id": 576,
+                    "text": "llo",
+                    "log_prob": -0.123123
+                }, {
+                    "id": 4558,
+                    "text": " world",
+                    "log_prob": -0.567854
+                }]
+            }
+        }
+        # Jsonlines tests
+        req = Request(0, "This is a wonderful day", {
+            "max_new_tokens": 256,
+            "details": True
+        })
+        req.set_next_token(Token(244, "He", -0.334532),
+                           _jsonlines_output_formatter)
+        req.set_next_token(Token(576, "llo", -0.123123),
+                           _jsonlines_output_formatter)
+        req.set_next_token(Token(4558, " world", -0.567854),
+                           _jsonlines_output_formatter, True, 'length')
+        print(req.get_next_token(), end='')
+        assert json.loads(req.get_next_token()) == {
+            "token": {
+                "id": 4558,
+                "text": " world",
+                "log_prob": -0.567854
+            },
+            "generated_text": "Hello world",
+            "details": {
+                "finish_reason": "length"
+            }
+        }
+
+
+if __name__ == '__main__':
+    unittest.main()