-
Notifications
You must be signed in to change notification settings - Fork 68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[python] Refactor output formatter for rolling batch #916
Conversation
f9abb7b
to
c43dab1
Compare
@@ -173,8 +173,6 @@ def initialize(self, properties: dict): | |||
self.device = int(os.getenv("LOCAL_RANK", 0)) | |||
_rolling_batch_cls = get_rolling_batch_class_from_str( | |||
self.rolling_batch_type, is_mpi, self.model_config) | |||
|
|||
# TODO: Allow user to set output formatter | |||
self.rolling_batch = _rolling_batch_cls(model_id_or_path, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can make output formatter as part of the kwargs, user can pass in a function if needed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed.
@@ -143,7 +143,13 @@ def _prefill_and_decode(self, new_requests): | |||
for request_id, generated_token, request in zip( | |||
request_ids, generated_tokens, self.pending_requests): | |||
is_last_token = (request_id in exit_req_ids) | |||
request.set_next_token(generated_token, last_token=is_last_token) | |||
if self.output_formatter is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
request.set_next_token(generated_token, self.output_formatter, last_token=is_last_token)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed
""" | ||
|
||
self.device = device | ||
self.pending_requests = [] | ||
self.req_id_counter = 0 | ||
if 'rolling_batch_output_formatter' in kwargs: | ||
# TODO: Allow user to set custom output formatter | ||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.output_formatter = kwargs['rolling_batch_output_formatter']
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed
# TODO: find a way to return content-type for custom output formatter | ||
if self.output_formatter == _default_output_formatter: | ||
return "application/jsonlines" | ||
return None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can be plain text?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe some other types
…#916) * [python] Refactor output formatter for rolling batch * Review changes * Review changes * Review changes * Review changes
…#916) * [python] Refactor output formatter for rolling batch * Review changes * Review changes * Review changes * Review changes
Description
Brief description of what this PR is about