unexpected output for HF model with batching #294

Closed · Borda opened this issue Sep 24, 2024 · 5 comments · Fixed by #300

Labels: bug (Something isn't working), good first issue (Good for newcomers), help wanted (Extra attention is needed)

Comments

Borda (Member) commented Sep 24, 2024:

🐛 Bug

To Reproduce

Without batching, everything works as expected:

{'output': 'What is the capital of Greece?\n\nAthens.'}

but with batching enabled, it returns just the first character:

{'output': 'W'}

Code sample

import torch
import litserve as ls
from transformers import AutoTokenizer, AutoModelForCausalLM


class JambaLitAPI(ls.LitAPI):

    def __init__(
        self,
        model_name: str = "ai21labs/AI21-Jamba-1.5-Mini",
        max_new_tokens: int = 100
    ):
        super().__init__()
        self.model_name = model_name
        self.max_new_tokens = max_new_tokens

    def setup(self, device):
        # Load the model and tokenizer from the Hugging Face Hub
        # You can replace the model name with any causal LM from the Hub
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name, torch_dtype=torch.bfloat16, device_map="auto", use_mamba_kernels=False
        ).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, legacy=False)

    def decode_request(self, request):
        # Extract text from request
        # This assumes the request payload is of the form: {'input': 'Your input text here'}
        return request["input"]

    def predict(self, text):
        print(text)
        # Use the loaded pipeline to perform inference
        inputs = self.tokenizer(text, return_tensors='pt')
        input_ids = inputs.to(self.model.device)["input_ids"]
        print(input_ids)
        output_ids = self.model.generate(
            input_ids,
            max_new_tokens=self.max_new_tokens
        )[0]
        print(output_ids)
        text = self.tokenizer.decode(output_ids, skip_special_tokens=True)
        print(text)
        return text

    def encode_response(self, output):
        # Format the output from the model to send as a response
        # This example sends back the generated text
        return {"output": output}


if __name__ == "__main__":
    # Create an instance of your API
    api = JambaLitAPI()
    # Start the server, specifying the port
    server = ls.LitServer(api, accelerator="cuda", devices=1, max_batch_size=4)
    # print("run the server...")
    server.run(port=8000)
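
As a point of comparison, here is a minimal sketch of a batch-aware predict that avoids the problem. This is an illustration, not the reporter's code: it assumes LitServe passes predict a list of decoded requests when max_batch_size > 1 and no custom batch/unbatch is defined, and that the tokenizer has, or can be given, a pad token.

def predict(self, texts):
    # With max_batch_size > 1 and no custom batch(), `texts` arrives as a
    # list of prompts, one per queued request.
    if self.tokenizer.pad_token is None:
        # Some causal-LM tokenizers ship without a pad token; reuse EOS.
        self.tokenizer.pad_token = self.tokenizer.eos_token
    # Left padding is generally safer for batched decoder-only generation.
    self.tokenizer.padding_side = "left"
    inputs = self.tokenizer(texts, return_tensors="pt", padding=True).to(self.model.device)
    output_ids = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
    # Return one decoded string per request so the default unbatching yields
    # whole responses instead of single characters.
    return self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)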

Expected behavior

Environment

If you published a Studio with your bug report, we can automatically get this information. Otherwise, please describe:

  • PyTorch/Jax/Tensorflow Version (e.g., 1.0):
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, source):
  • Build command you used (if compiling from source):
  • Python version:
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

Borda added the bug and help wanted labels on Sep 24, 2024.
aniketmaurya (Collaborator) commented:

Context for the issue:

With batching enabled, we expect the predict method to return a list of predictions. When the user doesn't implement LitAPI.unbatch, we wrap the output with list(output) before sending it to the encode_response method.

In this case, list(prediction_output) was called on a string, which got split into individual characters. So we need to warn users about this.
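
Concretely, this is roughly what the default unbatching does when predict returns a plain string (illustrative snippet, not LitServe code):

output = "What is the capital of Greece?\n\nAthens."
# list() on a string splits it into characters, so each request in the
# batch receives a single character; the first client just gets 'W'.
print(list(output)[:4])  # ['W', 'h', 'a', 't']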

aniketmaurya added the good first issue label on Sep 24, 2024.
grumpyp (Contributor) commented Sep 24, 2024:

I'd like to fix that, @aniketmaurya. Will raise a PR. Thanks.

aniketmaurya (Collaborator) commented:

@grumpyp looking forward to it! Please let me know if you have any questions.

grumpyp (Contributor) commented Sep 25, 2024:

> Context for the issue:
>
> With batching enabled, we expect the predict method to return a list of predictions. When the user doesn't implement LitAPI.unbatch, we wrap the output with list(output) before sending it to the encode_response method.
>
> In this case, list(prediction_output) was called on a string, which got split into individual characters. So we need to warn users about this.

Would you want to prevent this from happening, or can you think of cases where it is needed, in which case I'll only add a warning when the output is a string and batching is enabled? I could additionally introduce a parameter to enforce list-like outputs.

aniketmaurya (Collaborator) commented Sep 25, 2024:

@grumpyp let's just print a warning for now and watch for any new issues around this. You can add the logic here.
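
For context, a rough sketch of the kind of check being discussed; the names and placement here are hypothetical, and the actual change landed in LitServe's unbatching path via PR #300:

import warnings

def unbatch_no_stream_sketch(output):
    # Hypothetical stand-in for LitServe's default unbatching helper.
    if isinstance(output, str):
        warnings.warn(
            "predict returned a string while batching is enabled; it will be "
            "split into characters. Return a list with one prediction per "
            "request, or implement LitAPI.unbatch."
        )
    return list(output)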

aniketmaurya added a commit that referenced this issue on Oct 19, 2024:

fix: add warning unexpected output from HF model (closes #294)

* add: warning if batched loop return string
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* move output check to unbatch_no_stream
* restore format loops.py
* fix: lint E501 Line too long
* Update tests/test_batch.py
* Update src/litserve/api.py
* Update test to match new warning string
* Delete whitespace in warning string test_batch
* Update test_batch.py
* Update warning copy
* update test
* fix test
* fix Pyright issues in README examples, init parameter naming in LitAPI
* remove test from outdated
* cleanup
* add Any return type
* revert
* add host parameter to LitServer run method
* revert change in api.py
* add: tests and validation

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Aniket Maurya <theaniketmaurya@gmail.com>