make jsonline outputs generated tokens (#1454)
Qing Lan authored Jan 5, 2024
1 parent 51b96d6 commit 6fcda33
Showing 2 changed files with 137 additions and 10 deletions.
29 changes: 19 additions & 10 deletions engines/python/setup/djl_python/rolling_batch/rolling_batch.py
@@ -55,7 +55,7 @@ def as_dict(self):


def _json_output_formatter(token: Token, first_token: bool, last_token: bool,
-                           details: dict):
+                           details: dict, generated_tokens: str):
    """
    json output formatter
@@ -74,18 +74,21 @@ def _json_output_formatter(token: Token, first_token: bool, last_token: bool,


def _jsonlines_output_formatter(token: Token, first_token: bool,
-                                last_token: bool, details: dict):
+                                last_token: bool, details: dict,
+                                generated_tokens: str):
    """
    jsonlines output formatter
    :return: formatted output
    """
    token_dict = token.as_dict()
    final_dict = {"token": token_dict}
-    if last_token and details:
-        final_dict["details"] = {
-            "finish_reason": details.get("finish_reason", None)
-        }
+    if last_token:
+        final_dict["generated_text"] = generated_tokens
+        if details:
+            final_dict["details"] = {
+                "finish_reason": details.get("finish_reason", None)
+            }
    json_encoded_str = json.dumps(final_dict, ensure_ascii=False) + "\n"
    return json_encoded_str
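
For illustration, a minimal sketch of the new last-token behaviour (not part of this commit; the values and the import path are taken from the unit tests added below):

from djl_python.rolling_batch.rolling_batch import Token, _jsonlines_output_formatter

# Final token of a "Hello world" generation: the last jsonlines record now
# carries the complete generated_text next to the usual token payload.
line = _jsonlines_output_formatter(Token(4558, " world", -0.567854),
                                   first_token=False,
                                   last_token=True,
                                   details={"finish_reason": "length"},
                                   generated_tokens="Hello world")
print(line, end='')
# {"token": {"id": 4558, "text": " world", "log_prob": -0.567854}, "generated_text": "Hello world", "details": {"finish_reason": "length"}}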

@@ -115,6 +118,7 @@ def __init__(self, id: int, input_text: str, parameters: dict):
        self.first_token = True
        self.last_token = False
        self.token_cache = None
+        self.generated_tokens = []
        if parameters.pop("details", False):
            self.token_cache = []

@@ -141,16 +145,21 @@ def set_next_token(self,
            next_token = Token(-1, next_token)
        if self.token_cache is not None:
            self.token_cache.append(next_token.as_dict())
+        self.generated_tokens.append(next_token.text)
        details = {}
-        if last_token and self.token_cache is not None:
-            details["finish_reason"] = finish_reason
-            details["tokens"] = self.token_cache
+        generated_text = None
+        if last_token:
+            generated_text = ''.join(self.generated_tokens)
+            if self.token_cache is not None:
+                details["finish_reason"] = finish_reason
+                details["tokens"] = self.token_cache
        if output_formatter is None:
            self.next_token_str = next_token.text
        else:  # output only supports size one now
            self.next_token_str = output_formatter(next_token,
                                                   self.first_token,
-                                                   last_token, details)
+                                                   last_token, details,
+                                                   generated_text)
        self.last_token = last_token
        self.first_token = False

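From a client's perspective, a hypothetical consumer-side sketch (not code from this commit): because the final jsonlines record now repeats the complete text, a streaming client no longer has to stitch the per-token text fragments together itself.

import json

def read_generated_text(jsonlines_body: str) -> str:
    """Return the full generation from a jsonlines response body (assumed helper)."""
    last_record = None
    for line in jsonlines_body.splitlines():
        if line.strip():
            last_record = json.loads(line)
    # Previously the client had to join token["text"] across all records;
    # after this change the last record already contains "generated_text".
    return last_record["generated_text"]
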
118 changes: 118 additions & 0 deletions engines/python/setup/djl_python/tests/test_rolling_batch.py
@@ -0,0 +1,118 @@
import json
import unittest
from djl_python.rolling_batch.rolling_batch import Request, Token, _json_output_formatter, _jsonlines_output_formatter


class TestRollingBatch(unittest.TestCase):

def test_json_fmt(self):
req = Request(0, "This is a wonderful day", {"max_new_tokens": 256})
req.set_next_token(Token(244, "He", -0.334532), _json_output_formatter)
print(req.get_next_token(), end='')
assert req.get_next_token() == '{"generated_text": "He'
req.set_next_token(Token(576, "llo", -0.123123),
_json_output_formatter)
print(req.get_next_token(), end='')
assert req.get_next_token() == 'llo'
req.set_next_token(Token(4558, " world", -0.567854),
_json_output_formatter, True, 'length')
print(req.get_next_token(), end='')
assert req.get_next_token() == ' world"}'

def test_jsonlines_fmt(self):
req = Request(0, "This is a wonderful day", {"max_new_tokens": 256})
req.set_next_token(Token(244, "He", -0.334532),
_jsonlines_output_formatter)
print(req.get_next_token(), end='')
assert json.loads(req.get_next_token()) == {
"token": {
"id": 244,
"text": "He",
"log_prob": -0.334532
}
}
req.set_next_token(Token(576, "llo", -0.123123),
_jsonlines_output_formatter)
print(req.get_next_token(), end='')
assert json.loads(req.get_next_token()) == {
"token": {
"id": 576,
"text": "llo",
"log_prob": -0.123123
}
}
req.set_next_token(Token(4558, " world", -0.567854),
_jsonlines_output_formatter, True, 'length')
print(req.get_next_token(), end='')
assert json.loads(req.get_next_token()) == {
"token": {
"id": 4558,
"text": " world",
"log_prob": -0.567854
},
"generated_text": "Hello world"
}

def test_details(self):
req = Request(0, "This is a wonderful day", {
"max_new_tokens": 256,
"details": True
})
final_str = []
req.set_next_token(Token(244, "He", -0.334532), _json_output_formatter)
final_str.append(req.get_next_token())
req.set_next_token(Token(576, "llo", -0.123123),
_json_output_formatter)
final_str.append(req.get_next_token())
req.set_next_token(Token(4558, " world", -0.567854),
_json_output_formatter, True, 'length')
final_str.append(req.get_next_token())
final_json = json.loads(''.join(final_str))
print(final_json)
assert final_json == {
"generated_text": "Hello world",
"details": {
"finish_reason":
"length",
"tokens": [{
"id": 244,
"text": "He",
"log_prob": -0.334532
}, {
"id": 576,
"text": "llo",
"log_prob": -0.123123
}, {
"id": 4558,
"text": " world",
"log_prob": -0.567854
}]
}
}
# Jsonlines tests
req = Request(0, "This is a wonderful day", {
"max_new_tokens": 256,
"details": True
})
req.set_next_token(Token(244, "He", -0.334532),
_jsonlines_output_formatter)
req.set_next_token(Token(576, "llo", -0.123123),
_jsonlines_output_formatter)
req.set_next_token(Token(4558, " world", -0.567854),
_jsonlines_output_formatter, True, 'length')
print(req.get_next_token(), end='')
assert json.loads(req.get_next_token()) == {
"token": {
"id": 4558,
"text": " world",
"log_prob": -0.567854
},
"generated_text": "Hello world",
"details": {
"finish_reason": "length"
}
}


if __name__ == '__main__':
unittest.main()
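
Since the new module ends with a standard unittest entry point, it can also be run on its own; one possible invocation (assuming the djl_python package under engines/python/setup is importable) is:

import unittest

# Hypothetical local run of the new tests; requires djl_python on PYTHONPATH.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "djl_python.tests.test_rolling_batch")
unittest.TextTestRunner(verbosity=2).run(suite)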
