From 542229abb3aba2032d4c52a878c0fd35ba299691 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Mon, 20 May 2024 23:36:43 +0800
Subject: [PATCH] fix paligemma inference

---
 src/llamafactory/chat/hf_engine.py   | 25 ++++++++++++++++++------
 src/llamafactory/chat/vllm_engine.py | 29 +++++++++++++++++-----------
 src/llamafactory/data/template.py    | 11 ++++++++---
 src/llamafactory/extras/constants.py |  2 ++
 4 files changed, 47 insertions(+), 20 deletions(-)

diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 57cdc89adf..f59029a153 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -8,6 +8,7 @@
 from transformers import GenerationConfig, TextIteratorStreamer
 
 from ..data import get_template_and_fix_tokenizer
+from ..extras.constants import IMAGE_TOKEN
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response
@@ -55,14 +56,28 @@ def _process_args(
         image: Optional["NDArray"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
-        if processor is not None and image is not None and "<image>" not in messages[0]["content"]:
-            messages[0]["content"] = "<image>" + messages[0]["content"]
+        if (
+            processor is not None
+            and image is not None
+            and not hasattr(processor, "image_seq_length")
+            and IMAGE_TOKEN not in messages[0]["content"]
+        ):  # llava case
+            messages[0]["content"] = IMAGE_TOKEN + messages[0]["content"]
 
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or generating_args["default_system"]
+        pixel_values = None
         prompt_ids, _ = template.encode_oneturn(
             tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
         )
+        if processor is not None and image is not None:  # add image features
+            image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+            batch_feature = image_processor(image, return_tensors="pt")
+            pixel_values = batch_feature.to(model.device)["pixel_values"]  # shape (B, C, H, W)
+            if hasattr(processor, "image_seq_length"):  # paligemma case
+                image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
+                prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
+
         prompt_length = len(prompt_ids)
         inputs = torch.tensor([prompt_ids], device=model.device)
 
@@ -122,10 +137,8 @@ def _process_args(
             logits_processor=get_logits_processor(),
         )
 
-        if processor is not None and image is not None:
-            image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-            pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
-            gen_kwargs["pixel_values"] = pixel_values.to(model.device)
+        if pixel_values is not None:
+            gen_kwargs["pixel_values"] = pixel_values
 
         return gen_kwargs, prompt_length
 
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 44b9651f73..31d03fbeb8 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -2,6 +2,7 @@
 from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
 
 from ..data import get_template_and_fix_tokenizer
+from ..extras.constants import IMAGE_TOKEN
 from ..extras.logging import get_logger
 from ..extras.misc import get_device_count, infer_optim_dtype
 from ..extras.packages import is_vllm_available
@@ -17,7 +18,6 @@
 
 
 if TYPE_CHECKING:
-    import torch
     from numpy.typing import NDArray
     from transformers.image_processing_utils import BaseImageProcessor
 
@@ -67,7 +67,7 @@ def __init__(
             patch_size = config.vision_config.patch_size
             self.image_feature_size = (image_size // patch_size) ** 2
             engine_args["image_input_type"] = "pixel_values"
-            engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids("<image>")
+            engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
             engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
             engine_args["image_feature_size"] = self.image_feature_size
         if getattr(config, "is_yi_vl_derived_model", None):
@@ -92,14 +92,28 @@ async def _generate(
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-        if self.processor is not None and image is not None and "<image>" not in messages[0]["content"]:
-            messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]
+
+        if (
+            self.processor is not None
+            and image is not None
+            and not hasattr(self.processor, "image_seq_length")
+            and IMAGE_TOKEN not in messages[0]["content"]
+        ):  # llava case
+            messages[0]["content"] = IMAGE_TOKEN * self.image_feature_size + messages[0]["content"]
 
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(
             tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
         )
+
+        if self.processor is not None and image is not None:  # add image features
+            image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
+            pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
+            multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
+        else:
+            multi_modal_data = None
+
         prompt_length = len(prompt_ids)
 
         use_beam_search: bool = self.generating_args["num_beams"] > 1
@@ -144,13 +158,6 @@ async def _generate(
             skip_special_tokens=True,
         )
 
-        if self.processor is not None and image is not None:
-            image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
-            pixel_values: "torch.Tensor" = image_processor(image, return_tensors="pt")["pixel_values"]
-            multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
-        else:
-            multi_modal_data = None
-
         result_generator = self.model.generate(
             prompt=None,
             sampling_params=sampling_params,
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index 66e9dca5f9..bf7133a963 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -290,10 +290,10 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
                 slot_items.append(placeholder)
                 if slot_pieces[1]:
                     slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
-        elif isinstance(slot, set):
-            if "bos_token" in slot:
+        elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
+            if "bos_token" in slot and tokenizer.bos_token_id is not None:
                 slot_items.append("'" + tokenizer.bos_token + "'")
-            elif "eos_token" in slot:  # do not use {{ eos_token }} since it may be replaced
+            elif "eos_token" in slot and tokenizer.eos_token_id is not None:
                 slot_items.append("'" + tokenizer.eos_token + "'")
         elif isinstance(slot, dict):
             raise ValueError("Dict is not supported.")
@@ -325,9 +325,11 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
         jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
         jinja_template += "{% set content = " + system_message + " + message['content'] %}"
         jinja_template += "{% endif %}"
+
     jinja_template += "{% if message['role'] == 'user' %}"
     user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
     jinja_template += "{{ " + user_message + " }}"
+
     jinja_template += "{% elif message['role'] == 'assistant' %}"
     assistant_message = _convert_slots_to_jinja(
         template.format_assistant.apply() + template.format_separator.apply(), tokenizer
@@ -614,6 +616,9 @@ def get_template_and_fix_tokenizer(
 _register_template(
     name="empty",
     format_user=StringFormatter(slots=["{{content}}"]),
     format_assistant=StringFormatter(slots=["{{content}}"]),
+    format_system=StringFormatter(slots=[{"bos_token"}, "{{content}}"]),
+    efficient_eos=True,
+    force_system=True,
 )
 
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index fecf0c38f7..cf31ad6683 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -22,6 +22,8 @@
 
 IGNORE_INDEX = -100
 
+IMAGE_TOKEN = "<image>"
+
 LAYERNORM_NAMES = {"norm", "ln"}
 
 METHODS = ["full", "freeze", "lora"]