Improve serving.
haotian-liu committed Feb 3, 2024
1 parent d4c2f1f commit 806cfe0
Showing 4 changed files with 66 additions and 68 deletions.
108 changes: 49 additions & 59 deletions llava/conversation.py
@@ -1,6 +1,9 @@
import dataclasses
from enum import auto, Enum
from typing import List, Tuple
import base64
from io import BytesIO
from PIL import Image


class SeparatorStyle(Enum):
@@ -106,79 +109,66 @@ def get_prompt(self):
def append_message(self, role, message):
self.messages.append([role, message])

def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
if image_process_mode == "Pad":
def expand2square(pil_img, background_color=(122, 116, 104)):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image)
elif image_process_mode in ["Default", "Crop"]:
pass
elif image_process_mode == "Resize":
image = image.resize((336, 336))
else:
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
if max(image.size) > max_len:
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
if return_pil:
return image
else:
buffered = BytesIO()
image.save(buffered, format=image_format)
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
return img_b64_str
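
A minimal sketch of how the new process_image helper behaves under the different modes — an illustration, not part of the commit; the sample image size and the default_conversation template are assumptions:

from PIL import Image
from llava.conversation import default_conversation  # assumes the module-level default template

conv = default_conversation.copy()
img = Image.new("RGB", (4032, 3024))  # stand-in for a user upload

# "Pad" first letterboxes the image to a 4032x4032 square, then the size cap
# (max_len=1344, min_len=672) shrinks it to 672x672.
padded = conv.process_image(img, "Pad", return_pil=True)

# "Default" keeps the aspect ratio: a 4032x3024 photo comes back as 896x672.
resized = conv.process_image(img, "Default", return_pil=True)

# With return_pil=False (the default) the result is a base64-encoded PNG string instead.
b64_png = conv.process_image(img, "Default")

print(padded.size, resized.size, len(b64_png) > 0)  # (672, 672) (896, 672) True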

def get_images(self, return_pil=False):
images = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
from PIL import Image
msg, image, image_process_mode = msg
if image_process_mode == "Pad":
def expand2square(pil_img, background_color=(122, 116, 104)):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
image = expand2square(image)
elif image_process_mode in ["Default", "Crop"]:
pass
elif image_process_mode == "Resize":
image = image.resize((336, 336))
else:
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if longest_edge != max(image.size):
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
if return_pil:
images.append(image)
else:
buffered = BytesIO()
image.save(buffered, format="PNG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
images.append(img_b64_str)
image = self.process_image(image, image_process_mode, return_pil=return_pil)
images.append(image)
return images

def to_gradio_chatbot(self):
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]):
if i % 2 == 0:
if type(msg) is tuple:
import base64
from io import BytesIO
msg, image, image_process_mode = msg
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 800, 400
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
longest_edge = int(shortest_edge * aspect_ratio)
W, H = image.size
if H > W:
H, W = longest_edge, shortest_edge
else:
H, W = shortest_edge, longest_edge
image = image.resize((W, H))
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
img_b64_str = self.process_image(
image, "Default", return_pil=False,
image_format='JPEG')
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
msg = img_str + msg.replace('<image>', '').strip()
ret.append([msg, None])
else:
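Net effect of this file's change: the inline padding / resizing / base64 logic that previously lived in both get_images and to_gradio_chatbot is factored into the single process_image helper, and the chatbot thumbnail is now emitted as JPEG with a matching data-URI prefix. A rough usage sketch, where the template name and the prompt are placeholder assumptions:

from PIL import Image
from llava.conversation import conv_templates  # assumes the usual template registry

conv = conv_templates["vicuna_v1"].copy()      # template name is an assumption
image = Image.new("RGB", (1024, 768))          # stand-in for a user upload

# Image-bearing turns are stored as (text, image, image_process_mode) tuples.
conv.append_message(conv.roles[0], ("Describe this image.", image, "Default"))
conv.append_message(conv.roles[1], None)

pil_images = conv.get_images(return_pil=True)  # now routed through process_image
chat_rows = conv.to_gradio_chatbot()           # first row embeds a base64 JPEG <img> tag
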
16 changes: 11 additions & 5 deletions llava/model/builder.py
@@ -23,7 +23,7 @@
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", **kwargs):
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
kwargs = {"device_map": device_map, **kwargs}

if device != "cuda":
@@ -42,6 +42,9 @@ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, l
else:
kwargs['torch_dtype'] = torch.float16

if use_flash_attn:
kwargs['attn_implementation'] = 'flash_attention_2'

if 'llava' in model_name.lower():
# Load LLaVA model
if 'lora' in model_name.lower() and model_base is None:
@@ -88,7 +91,7 @@ def load_from_hf(repo_id, filename, subfolder=None):
shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
model = LlavaMPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
else:
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
cfg_pretrained = AutoConfig.from_pretrained(model_path)
@@ -100,18 +103,21 @@ def load_from_hf(repo_id, filename, subfolder=None):
else:
if 'mpt' in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
model = LlavaMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
elif 'mistral' in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = LlavaMistralForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
use_flash_attention_2=False,
**kwargs
)
else:
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
model = LlavaLlamaForCausalLM.from_pretrained(
model_path,
low_cpu_mem_usage=True,
**kwargs
)
else:
# Load language model
if model_base is not None:
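Besides switching to the renamed LlavaMptForCausalLM class, the loader gains a use_flash_attn flag that maps to attn_implementation='flash_attention_2'. A hedged sketch of calling it — the checkpoint id is a placeholder, get_model_name_from_path is assumed to be the helper used elsewhere in the repo, and the flash-attn package must be installed for the flag to take effect:

from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path  # assumed helper from the repo

model_path = "liuhaotian/llava-v1.6-vicuna-7b"  # placeholder checkpoint id
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path),
    load_8bit=False,
    load_4bit=False,
    device="cuda",
    use_flash_attn=True,  # new flag from this commit; requires flash-attn to be installed
)
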
8 changes: 5 additions & 3 deletions llava/serve/model_worker.py
@@ -45,7 +45,7 @@ class ModelWorker:
def __init__(self, controller_addr, worker_addr,
worker_id, no_register,
model_path, model_base, model_name,
load_8bit, load_4bit, device):
load_8bit, load_4bit, device, use_flash_attn=False):
self.controller_addr = controller_addr
self.worker_addr = worker_addr
self.worker_id = worker_id
@@ -63,7 +63,7 @@ def __init__(self, controller_addr, worker_addr,
self.device = device
logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device, use_flash_attn=use_flash_attn)
self.is_multimodal = 'llava' in self.model_name.lower()

if not no_register:
@@ -268,6 +268,7 @@ async def get_status(request: Request):
parser.add_argument("--no-register", action="store_true")
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
parser.add_argument("--use-flash-attn", action="store_true")
args = parser.parse_args()
logger.info(f"args: {args}")

@@ -283,5 +284,6 @@ async def get_status(request: Request):
args.model_name,
args.load_8bit,
args.load_4bit,
args.device)
args.device,
use_flash_attn=args.use_flash_attn)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
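
The worker now threads a --use-flash-attn switch from the command line down to the loader. A minimal reproduction of the wiring — the parser below only mirrors the flags relevant here, not the worker's full argument list:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
parser.add_argument("--use-flash-attn", action="store_true")
args = parser.parse_args(["--use-flash-attn"])  # simulate launching with the new flag

# ModelWorker forwards the flag unchanged to load_pretrained_model, which then sets
# kwargs['attn_implementation'] = 'flash_attention_2' before from_pretrained is called.
assert args.use_flash_attn and not args.load_4bit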
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "llava"
version = "1.2.1.post4"
version = "1.2.2"
description = "Towards GPT-4 like large language and visual assistant."
readme = "README.md"
requires-python = ">=3.8"
