
enable llava static generation. #767

Merged (10 commits, Apr 25, 2024)
fix ut and code style.
lkk12014402 committed Apr 24, 2024
commit 02ee3d46af3d077839597b5b592e94ef3fc8da4a
36 changes: 32 additions & 4 deletions examples/image-to-text/run_pipeline.py
@@ -16,13 +16,14 @@
import argparse
import logging
import time

import json
import PIL.Image
import requests
import torch
from transformers import pipeline

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from pathlib import Path


logging.basicConfig(
@@ -67,12 +68,23 @@ def main():
action="store_true",
help="Whether to perform generation in bf16 precision.",
)
parser.add_argument(
"--output_dir",
default=None,
type=str,
help="Output directory to store results in.",
)
parser.add_argument("--batch_size", type=int, default=1, help="Input batch size.")
parser.add_argument("--warmup", type=int, default=3, help="Number of warmup iterations for benchmarking.")
parser.add_argument("--n_iterations", type=int, default=5, help="Number of inference iterations for benchmarking.")
args = parser.parse_args()

adapt_transformers_to_gaudi()
if args.image_path is None and "llava" in args.model_name_or_path:
args.image_path = ["https://llava-vl.github.io/static/images/view.jpg"]
if args.prompt is None and "llava" in args.model_name_or_path:
args.prompt = "<image>\nUSER: What's the content of the image?\nASSISTANT:"

image_paths = args.image_path
image_paths_len = len(image_paths)

@@ -115,11 +127,27 @@ def main():
for i in range(args.warmup):
generator(images, prompt=args.prompt, batch_size=args.batch_size, generate_kwargs=generate_kwargs)

start = time.time()
start = time.perf_counter()
for i in range(args.n_iterations):
result = generator(images, prompt=args.prompt, batch_size=args.batch_size, generate_kwargs=generate_kwargs)
end = time.time()
logger.info(f"result = {result}, time = {(end-start) * 1000 / args.n_iterations }ms")
end = time.perf_counter()
duration = end - start

total_new_tokens_generated = args.n_iterations * args.batch_size * args.max_new_tokens
throughput = total_new_tokens_generated / duration
logger.info(f"result = {result}, time = {(end-start) * 1000 / args.n_iterations }ms, Throughput (including tokenization) = {throughput} tokens/second")

# Store results if necessary
if args.output_dir is not None:
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

results = {
"throughput": throughput,
"output": result,
}
with (output_dir / "results.json").open("w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=4)


if __name__ == "__main__":
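An aside on the metric introduced above: the pipeline now times args.n_iterations calls with time.perf_counter() and reports a throughput that counts every requested new token, so it includes tokenization and host-side overhead. Below is a minimal sketch of that arithmetic with purely illustrative numbers (the duration is an assumption, not a measurement from this PR):

```python
# Minimal sketch of the throughput arithmetic added above
# (illustrative values, not measurements from this PR).
n_iterations = 5        # --n_iterations default
batch_size = 1          # --batch_size default
max_new_tokens = 20     # value the CI test passes
duration = 1.25         # hypothetical time.perf_counter() delta in seconds

total_new_tokens_generated = n_iterations * batch_size * max_new_tokens  # 100 tokens
throughput = total_new_tokens_generated / duration                       # 80.0 tokens/second
print(f"Throughput (including tokenization) = {throughput} tokens/second")
```

The throughput value also lands in the results.json written under --output_dir, which is what the updated test reads back and compares to the Gaudi2 baseline.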
2 changes: 1 addition & 1 deletion optimum/habana/transformers/modeling_utils.py
@@ -45,6 +45,7 @@
GaudiLlamaMLP,
GaudiLlamaModel,
GaudiLlamaRotaryEmbedding,
GaudiLlavaForConditionalGeneration,
GaudiMistralAttention,
GaudiMistralDecoderLayer,
GaudiMistralForCausalLM,
@@ -55,7 +56,6 @@
GaudiOPTForCausalLM,
GaudiOPTLearnedPositionalEmbedding,
GaudiPhiForCausalLM,
GaudiLlavaForConditionalGeneration,
_gaudi_wav2vec2_compute_mask_indices,
_gaudi_wav2vec2_mask_hidden_states,
gaudi_albert_forward,
2 changes: 1 addition & 1 deletion optimum/habana/transformers/models/__init__.py
@@ -82,6 +82,7 @@
GaudiLlamaRotaryEmbedding,
gaudi_llama_rmsnorm_forward,
)
from .llava import GaudiLlavaForConditionalGeneration
from .mistral import (
GaudiMistralAttention,
GaudiMistralDecoderLayer,
@@ -150,4 +151,3 @@
gaudi_wav2vec2_tdnnlayer_forward,
gaudi_wav2vec2forctc_forward,
)
from .llava import GaudiLlavaForConditionalGeneration
69 changes: 35 additions & 34 deletions optimum/habana/transformers/models/llava/modeling_llava.py
@@ -19,16 +19,20 @@
# limitations under the License.
"""PyTorch Llava model."""

import torch
from typing import List, Optional, Tuple, Union

import torch
from transformers.cache_utils import Cache
from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration, LlavaCausalLMOutputWithPast
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast, LlavaForConditionalGeneration
from transformers.utils import logging


logger = logging.get_logger(__name__)

def _pad_inputs(input_ids, attention_mask, image_token_index, num_patches,
pad_token_id, vision_feature_select_strategy=None):

def _pad_inputs(
input_ids, attention_mask, image_token_index, num_patches, pad_token_id, vision_feature_select_strategy=None
):
"""
pad inputs for static shape
"""
@@ -38,17 +42,13 @@ def _pad_inputs(input_ids, attention_mask, image_token_index,
elif vision_feature_select_strategy == "full":
num_patches = num_patches + 1
else:
raise ValueError(
f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
)
raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}")
image_offset = 0
new_input_ids = []
new_attention_mask = []
tokens_pos = []
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask):
num_images = (cur_input_ids == image_token_index).sum()
image_token_indices = torch.where(cur_input_ids == image_token_index)[0].tolist() + \
[cur_input_ids.shape[0]]
image_token_indices = torch.where(cur_input_ids == image_token_index)[0].tolist() + [cur_input_ids.shape[0]]

cur_input_ids_extend = []
cur_attention_mask_extend = []
@@ -167,7 +167,8 @@ def forward(

image_features = self.multi_modal_projector(selected_image_feature)
inputs_embeds = _merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids, self.config.image_token_index)
image_features, inputs_embeds, input_ids, self.config.image_token_index
)

outputs = self.language_model(
attention_mask=attention_mask,
@@ -204,19 +205,20 @@ def forward(

else:
return super().forward(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict)
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
vision_feature_layer=vision_feature_layer,
vision_feature_select_strategy=vision_feature_select_strategy,
labels=labels,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)

def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, pixel_values=None, attention_mask=None, **kwargs
@@ -234,24 +236,23 @@ def prepare_inputs_for_generation(
image_offset = 0
tokens_pos = None
if token_idx is not None and pixel_values is not None:
input_ids, attention_mask, image_offset, tokens_pos = \
_pad_inputs(input_ids,
attention_mask,
self.config.image_token_index,
self.vision_tower.vision_model.embeddings.num_patches,
self.pad_token_id,
vision_feature_select_strategy=self.config.vision_feature_select_strategy)
input_ids, attention_mask, image_offset, tokens_pos = _pad_inputs(
input_ids,
attention_mask,
self.config.image_token_index,
self.vision_tower.vision_model.embeddings.num_patches,
self.pad_token_id,
vision_feature_select_strategy=self.config.vision_feature_select_strategy,
)

past_length = 0
if past_key_values is not None:
if token_idx is None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
@@ -296,7 +297,7 @@ def prepare_inputs_for_generation(
"pixel_values": pixel_values,
"token_idx": token_idx,
"image_offset": image_offset,
"tokens_pos": tokens_pos
"tokens_pos": tokens_pos,
}
)

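For readers skimming the diff: _pad_inputs exists so that the merged text-plus-image sequence has a fixed length before generation starts, which lets HPU graphs be compiled for static shapes. Each image placeholder token is expanded into as many positions as the vision tower produces patches, and those positions are later overwritten by the projected image features in _merge_input_ids_with_image_features. The snippet below is a simplified, hypothetical sketch of that expansion for a single unbatched sequence; it is not the code from this PR and ignores the vision_feature_select_strategy adjustment, batching, and the image_offset/tokens_pos bookkeeping:

```python
import torch


def expand_image_tokens(input_ids, attention_mask, image_token_index, num_patches, pad_token_id):
    """Simplified sketch: replace each image token with num_patches placeholder
    positions so the final embedding sequence has a static, known length."""
    new_ids, new_mask = [], []
    for tok, msk in zip(input_ids.tolist(), attention_mask.tolist()):
        if tok == image_token_index:
            # Reserve one slot per vision patch; in the real model these slots
            # are filled with projected image features, not pad embeddings.
            new_ids.extend([pad_token_id] * num_patches)
            new_mask.extend([1] * num_patches)
        else:
            new_ids.append(tok)
            new_mask.append(msk)
    return torch.tensor(new_ids), torch.tensor(new_mask)


# Hypothetical values: a 4-token prompt containing one image placeholder (id 32000),
# with a vision tower that yields 576 patches.
ids = torch.tensor([1, 32000, 306, 29973])
mask = torch.ones_like(ids)
padded_ids, padded_mask = expand_image_tokens(ids, mask, image_token_index=32000, num_patches=576, pad_token_id=0)
print(padded_ids.shape)  # torch.Size([579]) -> 3 text tokens + 576 patch slots
```

With the length fixed up front, prepare_inputs_for_generation can rely on the token_idx kwarg to step through pre-allocated positions instead of reshaping the inputs at every decoding step.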
27 changes: 4 additions & 23 deletions tests/test_image_to_text_example.py
@@ -14,20 +14,14 @@
# Gaudi2 CI baselines
MODELS_TO_TEST = {
"bf16": [
("llava-hf/llava-1.5-7b-hf", 1, False, 264.1947269439697),
("llava-hf/llava-1.5-7b-hf", 1, 70.42849862047758),
],
}

def _test_image_to_text(
model_name: str,
baseline: float,
token: str,
batch_size: int = 1,
reuse_cache: bool = False,
deepspeed: bool = False,
world_size: int = 8,
torch_compile: bool = False,
fp8: bool = False,
):
command = ["python3"]
path_to_example_dir = Path(__file__).resolve().parent.parent / "examples"
@@ -37,7 +31,6 @@ def _test_image_to_text(
f"{path_to_example_dir / 'image-to-text' / 'run_pipeline.py'}",
f"--model_name_or_path {model_name}",
f"--batch_size {batch_size}",
"--use_kv_cache",
"--max_new_tokens 20",
Collaborator comment:
Have you run the test with the following command?
GAUDI2_CI=1 RUN_SLOW=true python -m pytest tests/test_image_to_text_example.py -v -s
If so, you will see: run_pipeline.py: error: unrecognized arguments: --use_kv_cache --output_dir /tmp/tmpsp9f6li_ --token None
You should only include arguments that run_pipeline.py accepts, e.g.:
python3 run_pipeline.py
--model_name_or_path "llava-hf/llava-1.5-7b-hf"
--image_path "https://llava-vl.github.io/static/images/view.jpg"
--prompt "<image>\nUSER: What's the content of the image?\nASSISTANT:"
--max_new_tokens 20
--use_hpu_graphs
--bf16

Contributor (PR author): updated

]

@@ -51,21 +44,9 @@
command.append(f"--output_dir {tmp_dir}")
print(f"\n\nCommand to test: {' '.join(command)}\n")

command.append(f"--token {token.value}")

pattern = re.compile(r"([\"\'].+?[\"\'])|\s")
command = [x for y in command for x in re.split(pattern, y) if x]

if fp8:
env_variables["QUANT_CONFIG"] = os.path.join(
path_to_example_dir, "text-generation/quantization_config/maxabs_measure_include_outputs.json"
)
subprocess.run(command, env=env_variables)
env_variables["QUANT_CONFIG"] = os.path.join(
path_to_example_dir, "text-generation/quantization_config/maxabs_quant.json"
)
command.insert(-2, "--fp8")

proc = subprocess.run(command, env=env_variables)

# Ensure the run finished without any issue
@@ -84,6 +65,6 @@ def _test_image_to_text(
assert results["throughput"] >= (2 - TIME_PERF_FACTOR) * baseline


@pytest.mark.parametrize("model_name, batch_size, reuse_cache, baseline", MODELS_TO_TEST["bf16"])
def test_text_generation_bf16(model_name: str, baseline: float, batch_size: int, reuse_cache: bool, token: str):
_test_image_to_text(model_name, baseline, token, batch_size, reuse_cache)
@pytest.mark.parametrize("model_name, batch_size, baseline", MODELS_TO_TEST["bf16"])
def test_image_to_text_bf16(model_name: str, baseline: float, batch_size: int):
_test_image_to_text(model_name, baseline, batch_size)
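
The bf16 case is now parametrized only by (model_name, batch_size, baseline), and the renamed test compares the throughput read back from results.json against the Gaudi2 baseline. A hedged sketch of that pass/fail arithmetic, assuming TIME_PERF_FACTOR is 1.05 (the real constant is defined elsewhere in the test suite and may differ):

```python
# Illustrative restatement of the final assertion in _test_image_to_text.
# TIME_PERF_FACTOR = 1.05 is an assumption; measured_throughput is hypothetical.
TIME_PERF_FACTOR = 1.05
baseline = 70.42849862047758   # Gaudi2 CI baseline for llava-1.5-7b-hf in bf16
measured_throughput = 72.0     # hypothetical value read from results.json

threshold = (2 - TIME_PERF_FACTOR) * baseline  # 0.95 * baseline, ~66.9 tokens/second
assert measured_throughput >= threshold
```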