Skip to content

Commit

Permalink
Fix intolerance for new field in TGI stream response: 'index' (#2006)
Browse files Browse the repository at this point in the history
* Fixing intolerance for new TGI field 'index'

* make style

* fix test python38

---------

Co-authored-by: Lucain Pouget <lucainp@gmail.com>
  • Loading branch information
danielpcox and Wauplin authored Feb 6, 2024
1 parent 244e3ef commit 5433ea9
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/huggingface_hub/inference/_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,8 @@ class TextGenerationStreamResponse:
Args:
token (`Token`):
The generated token.
index (`Optional[int]`, *optional*):
The token index within the stream. Optional to support older clients that omit it.
generated_text (`Optional[str]`, *optional*):
The complete generated text. Only available when the generation is finished.
details (`Optional[StreamDetails]`, *optional*):
Expand All @@ -459,6 +461,9 @@ class TextGenerationStreamResponse:

# Generated token
token: Token
# The token index within the stream
# Optional to support older clients that omit it.
index: Optional[int] = None
# Complete generated text
# Only available when the generation is finished
generated_text: Optional[str] = None
Expand Down
30 changes: 30 additions & 0 deletions tests/test_inference_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
OverloadedError,
TextGenerationParameters,
TextGenerationRequest,
TextGenerationStreamResponse,
raise_text_generation_error,
)
from huggingface_hub.inference._text_generation import (
Expand Down Expand Up @@ -106,6 +107,35 @@ def test_request_validation(self):
inputs="test", parameters=TextGenerationParameters(best_of=2, do_sample=True), stream=True
)

def test_streaming_response_validation(self):
"""
Regression test for #2005.
See https://github.com/huggingface/huggingface_hub/issues/2005
"""
json_payload_latest = {
"index": 97,
"token": {"id": 264, "text": " a", "logprob": -0.0003259182, "special": False},
"generated_text": None,
"details": None,
}
TextGenerationStreamResponse(**json_payload_latest)

json_payload_older_had_no_index = {
"token": {"id": 264, "text": " a", "logprob": -0.0003259182, "special": False},
"generated_text": None,
"details": None,
}
TextGenerationStreamResponse(**json_payload_older_had_no_index)

json_payload_without_required_field_token = {
"index": 97,
"generated_text": None,
"details": None,
}
with self.assertRaises((ValidationError, TypeError)):
TextGenerationStreamResponse(**json_payload_without_required_field_token)


class TestTextGenerationErrors(unittest.TestCase):
def test_generation_error(self):
Expand Down

0 comments on commit 5433ea9

Please sign in to comment.