enable llava static generation. #767
Conversation
Based on the image-to-text generation PR #738, I tested this on a single-card Gaudi2 with --use_hpu_graphs:
result = [[{'generated_text': "[\nUSER: What's the content of the image?\nASSISTANT: The image features a pier extending out into a large body of water, likely a lake.\n\n"}]], time = 264.1947269439697ms
Number of HPU graphs = 26
Just want to let you know this works like a charm!
@lkk12014402, could you please provide a brief description of the changes needed in optimum/habana/transformers/models/llava/modeling_llava.py with respect to the base model in transformers?
I see a couple of places using torch.where, which are usually dynamic on HPU. If these run on CPU, that's fine, but if they run on HPU, they might need rewriting:
batch_indices, image_indices = torch.where(input_ids == image_token_index)
image_token_indices = torch.where(cur_input_ids == image_token_index)[0].tolist() + \
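To illustrate why these calls matter, here is a minimal sketch (token ids and shapes are made up, not taken from the model) of how the index-returning form of torch.where yields data-dependent output shapes, together with a boolean-mask alternative whose shapes stay fixed:

```python
import torch

# torch.where on a boolean condition (no x/y arguments) returns index
# tensors whose length depends on the *values* in the input -- a
# data-dependent (dynamic) shape, which forces recompilation on
# accelerators that compile static-shape graphs.
image_token_index = 32000  # hypothetical placeholder token id
input_ids = torch.tensor([[1, 32000, 5, 32000]])

batch_indices, image_indices = torch.where(input_ids == image_token_index)
# len(batch_indices) == number of matches -> shape varies with the data

# A shape-stable alternative: keep a boolean mask of fixed shape and
# use masked arithmetic instead of gathering indices.
mask = input_ids == image_token_index  # always shape [1, 4]
num_image_tokens = mask.sum()          # scalar, shape-stable
```

The mask-based form is what static-shape backends generally prefer, since every tensor in the graph has a shape known ahead of time.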
Hi @ssarkar2, I will provide a description and check the torch.where() operations as soon as possible.
Hi @ssarkar2, here is the description.

Description
Assume the input text tokenizes to 4 tokens, one of which is the image placeholder token. Running generation with Hugging Face Transformers directly, llava-1.5-7b-hf produces a text embedding of shape [1, 4, 4096] and an image embedding of shape [1, 576, 4096]. The two embeddings are then merged into the final input embedding of shape [1, 579, 4096] (the placeholder token is replaced by the 576 image embeddings: 4 - 1 + 576 = 579) using the merge function here. That merge function contains many dynamic ops, such as torch.where, and the input shape changes during generation. So when we use Gaudi2 for generation, there are 2 problems:
Note: to reproduce, you can use this PR's image-to-text example.
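The shape arithmetic in the description can be checked with a small sketch (dummy tensors and the placeholder position are illustrative; the real merge lives in modeling_llava.py):

```python
import torch

# Illustrative shapes from the description (llava-1.5-7b-hf):
# 4 text tokens, one of which is the image placeholder, plus
# 576 image patch embeddings -> 4 - 1 + 576 = 579 positions.
hidden = 4096
text_embeds = torch.zeros(1, 4, hidden)
image_embeds = torch.zeros(1, 576, hidden)

placeholder_pos = 1  # hypothetical position of the placeholder token
merged = torch.cat(
    [
        text_embeds[:, :placeholder_pos],       # tokens before the image
        image_embeds,                           # 576 patch embeddings
        text_embeds[:, placeholder_pos + 1:],   # tokens after the image
    ],
    dim=1,
)
# merged has shape [1, 579, 4096]
```

Because placeholder_pos varies with the prompt, the slice boundaries above are data-dependent, which is exactly the kind of dynamism the PR works around.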
My optimization
In order to keep the same Transformers usage (same input, same generation script) and enable static shapes by padding and inserting token_idx for generation, I add a new function. And to keep the input shape constant during generation, I also use token_idx. So I create 2 auxiliary variables; the explanation of maintaining

@ssarkar2 please help review~
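As a rough sketch of the padding-plus-token_idx idea (toy numbers, not the PR's actual implementation): pre-allocate the sequence to its maximum length and write each generated token in place, so tensor shapes never change across decode steps:

```python
import torch

max_length = 8  # hypothetical generation budget
prompt = torch.tensor([[101, 7, 9]])

# Pad the prompt up-front so every decode step sees the same shape.
input_ids = torch.zeros(1, max_length, dtype=torch.long)
input_ids[:, : prompt.shape[1]] = prompt

# token_idx tracks where the next token is written; writing in place
# replaces the usual torch.cat, whose output grows every step.
token_idx = prompt.shape[1]
for _ in range(3):
    next_token = 42  # stand-in for the model's sampled token
    input_ids[:, token_idx] = next_token
    token_idx += 1
# input_ids keeps shape [1, 8] throughout the loop
```

On a static-shape backend, token_idx would itself be kept as a device tensor so the compiled graph is identical at every step; a plain Python int is used here only to keep the sketch minimal.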
@lkk12014402 can you add a CI test case and rebase?
@libinta I will update the PR with your comments soon.
Hi @libinta, I have resolved the conflicting files. And I haven't seen an image-to-text example test case like
@lkk12014402 can you add a file like test_image2text_generation_example.py to include image2text generation? (See Line 76 in 081130d.)
Hi @libinta, please help review/check the image-to-text UT. Thanks~
tests/test_image_to_text_example.py (Outdated)
@pytest.mark.parametrize("model_name, batch_size, reuse_cache, baseline", MODELS_TO_TEST["bf16"])
def test_text_generation_bf16(model_name: str, baseline: float, batch_size: int, reuse_cache: bool, token: str):
Better to have image_to_text rather than text_generation in the test name.
fixed
f"--model_name_or_path {model_name}",
f"--batch_size {batch_size}",
"--use_kv_cache",
"--max_new_tokens 20",
Have you run the test with
GAUDI2_CI=1 RUN_SLOW=true python -m pytest tests/test_image_to_text_example.py -v -s
If so, you will see:
run_pipeline.py: error: unrecognized arguments: --use_kv_cache --output_dir /tmp/tmpsp9f6li_ --token None
You should include the same arguments as:
python3 run_pipeline.py \
    --model_name_or_path "llava-hf/llava-1.5-7b-hf" \
    --image_path "https://llava-vl.github.io/static/images/view.jpg" \
    --prompt "\nUSER: What's the content of the image?\nASSISTANT:" \
    --max_new_tokens 20 \
    --use_hpu_graphs \
    --bf16
updated
tests/test_image_to_text_example.py (Outdated)
pattern = re.compile(r"([\"\'].+?[\"\'])|\s")
command = [x for y in command for x in re.split(pattern, y) if x]

if fp8:
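The regex split above tokenizes the command string while keeping quoted arguments (such as the prompt) intact; a standalone sketch:

```python
import re

# Split on whitespace, but capture quoted substrings as single tokens
# so an argument like the prompt survives as one argv entry.
# re.split includes the captured group in its output, which is why the
# quoted chunk reappears as its own list element.
pattern = re.compile(r"([\"\'].+?[\"\'])|\s")

command = ['python3 run_pipeline.py --prompt "What is shown?" --bf16']
tokens = [x for y in command for x in re.split(pattern, y) if x]
# the quoted argument is kept whole, quotes included
```

Note the surrounding quote characters are preserved in the token; subprocess invocation code downstream would need to strip them if the argument is passed without a shell.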
remove fp8 section for now
removed
It seems there are some merge conflicts to solve, can you update your main branch and merge it into this one?
Also, please run
pip install -U ruff
make style
to have the code style check pass.
Updated the code style with the command.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hi @regisss, I updated the code per your comments.
Please review~ Thanks~
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
What does this PR do?
Support LLaVA image-to-text generation.