diff --git a/torchchat/cli/download.py b/torchchat/cli/download.py
index 69cfbf21c..8d7ab79c3 100644
--- a/torchchat/cli/download.py
+++ b/torchchat/cli/download.py
@@ -30,7 +30,7 @@ def _download_hf_snapshot(
     try:
         snapshot_download(
             model_config.distribution_path,
-            cache_dir=artifact_dir,
+            local_dir=artifact_dir,
             local_dir_use_symlinks=False,
             token=hf_token,
             ignore_patterns=None if "llava" in model_config.name else "*safetensors*",
diff --git a/torchchat/generate.py b/torchchat/generate.py
index 52f1e45fa..be1cab606 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -374,7 +374,7 @@ def prefill(
                 logits = model(x)
         else:
             # input_pos: [B, S]
-            logits = model(x, input_pos=input_pos)
+            logits = model(x, input_pos)
         # print(f"logits {logits.shape}")

         # print(f"x: {x},\n input_pos: {input_pos}\n")
@@ -398,7 +398,7 @@ def decode_one_token(
             else:
                 logits = model(x)
         else:
-            logits = model(x, input_pos=input_pos)
+            logits = model(x, input_pos)
         # print(f"x: {x},\n input_pos: {input_pos}\n")
         return self.sample(logits, need_probs=need_probs, **sampling_kwargs)

diff --git a/torchchat/model.py b/torchchat/model.py
index b23084d90..e14b0d2b8 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -562,10 +562,9 @@ def __init__(self, config: ModelArgs) -> None:
     def forward(
         self,
         tokens: Tensor,
-        *,
+        input_pos: Optional[Tensor] = None,
         encoder_input: Optional[Dict[str, Tensor]] = None,
         post_tokens: Optional[Tensor] = None,
-        input_pos: Optional[Tensor] = None,
     ) -> Tensor:
         return self.model(tokens, encoder_input=encoder_input, post_tokens=post_tokens, input_pos=input_pos)

@@ -1032,470 +1031,3 @@ def setup_caches(self, max_batch_size, max_seq_length):

         except:
             pass
-
-
-from torchvision import transforms as tvT
-
-def llava_image_preprocess(
-    # img_address: str,
-    target_h: int,
-    target_w: int,
-    rescale_factor: float,
-    image_mean: List[float],
-    image_std: List[float],
-    ) -> torch.Tensor:
-    """
-    Preprocess an image by resizing it to fit a target height and width,
-    padding with median RGB value to make a square, scaling, and normalizing.
-
-    Args:
-        img_address (str): Address of the local image file will be forwarded to the model.
-        target_h (int): Target height.
-        target_w (int): Target width.
-        rescale_factor (float): Rescaling factor.
-        image_mean (list): Mean values for normalization.
-        image_std (list): Standard deviation values for normalization.
-
-    Returns:
-        torch.Tensor: Preprocessed image tensor.
-
-    Raises:
-        FileNotFoundError: If the image file does not exist.
-        ValueError: If the target height or width is not positive.
- """ - - # # Check if the image file exists - # if not os.path.exists(img_address): - # raise FileNotFoundError("Image file not found") - - # Check if the target height and width are positive - if target_h <= 0 or target_w <= 0: - raise ValueError("Target height and width must be positive") - - # Load the image from the given address - image = Image.open( - requests.get( - "https://llava-vl.github.io/static/images/view.jpg", stream=True - ).raw) - # Convert the image to a tensor - img = tvT.functional.pil_to_tensor(image) - - # Calculate the height and width ratios - ratio_h = img.shape[1] / target_h - ratio_w = img.shape[2] / target_w - - # Resize the image to fit in a target_h x target_w canvas - ratio = max(ratio_h, ratio_w) - output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio)) - img = tvT.Resize(size=output_size)(img) - - # Pad the image with median RGB value to make a square - l_pad = (target_w - img.shape[2]) // 2 - t_pad = (target_h - img.shape[1]) // 2 - r_pad = -((target_w - img.shape[2]) // -2) - b_pad = -((target_h - img.shape[1]) // -2) - - torch._check(l_pad >= 0) - torch._check(t_pad >= 0) - torch._check(r_pad >= 0) - torch._check(b_pad >= 0) - - # Pad the image - resized = torch.nn.functional.pad( - img, - (l_pad, r_pad, t_pad, b_pad), - ) - - # Scale the image - scaled = resized * rescale_factor - - # Normalize the image - normed = tvT.Normalize(image_mean, image_std)(scaled) - - return normed.unsqueeze(0) - - - - - -if __name__ == "__main__": - import re - from PIL import Image - import requests - - def prepare_image(target_h: int, target_w: int) -> torch.Tensor: - """Read image into a tensor and resize the image so that it fits in - a target_h x target_w canvas. - - Args: - image (Image): An Image object. - target_h (int): Target height. - target_w (int): Target width. - - Returns: - torch.Tensor: resized image tensor. - """ - image = Image.open( - requests.get( - "https://llava-vl.github.io/static/images/view.jpg", stream=True - ).raw) - - img = torchvision.transforms.functional.pil_to_tensor(image) - # height ratio - ratio_h = img.shape[1] / target_h - # width ratio - ratio_w = img.shape[2] / target_w - # resize the image so that it fits in a target_h x target_w canvas - ratio = max(ratio_h, ratio_w) - output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio)) - img = torchvision.transforms.Resize(size=output_size)(img) - return img - - - def image_preprocess(img: torch.Tensor, target_h: int, target_w: int, rescale_factor, image_mean, image_std) -> torch.Tensor: - # pad the image with median rgb value, to make a square - l_pad = (target_w - img.shape[2]) // 2 - t_pad = (target_h - img.shape[1]) // 2 - # ceil division - r_pad = -((target_w - img.shape[2]) // -2) - b_pad = -((target_h - img.shape[1]) // -2) - - torch._check(l_pad >= 0) - torch._check(t_pad >= 0) - torch._check(r_pad >= 0) - torch._check(b_pad >= 0) - - # This is different from the original implementation, due to export limitations. - resized = torch.nn.functional.pad( - img, - (l_pad, r_pad, t_pad, b_pad), - ) - - scaled = resized * rescale_factor - from torchvision.transforms.v2 import functional as tvF - normed = tvF.normalize( - scaled, image_mean, image_std - ) - return normed.unsqueeze(0) - - - # def checkpoint_remap(llava_model, llava_ckpt): - # def _translate_state_dict_for_vision_model(hf_state_dict) -> Dict[str, Any]: - # translated_state_dict = {} - - # # Define the mapping from old names to new names - # hf_weight_prefix = "vision_model." 
-    #         name_mapping = {
-    #             f"{hf_weight_prefix}embeddings.class_embedding": "cls_token_embedding.weight",
-    #             f"{hf_weight_prefix}embeddings.position_embedding.weight": "token_pos_embedding.positional_embedding",
-    #             f"{hf_weight_prefix}embeddings.patch_embedding.weight": "conv.weight",
-    #             f"{hf_weight_prefix}pre_layrnorm.weight": "ln_pre.weight",
-    #             f"{hf_weight_prefix}pre_layrnorm.bias": "ln_pre.bias",
-    #             f"{hf_weight_prefix}post_layernorm.weight": "ln_post.weight",
-    #             f"{hf_weight_prefix}post_layernorm.bias": "ln_post.bias",
-    #         }
-
-    #         # Use regular expressions to define the mapping for each layer
-    #         patterns = [
-    #             (
-    #                 rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.self_attn\.(k|q|v)_proj\.(weight|bias)",
-    #                 lambda match: f"layers.{match.group(1)}.attn.{match.group(2)}_proj.{match.group(3)}",
-    #             ),
-    #             (
-    #                 rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.(weight|bias)",
-    #                 lambda match: f"layers.{match.group(1)}.attn.output_proj.{match.group(2)}",
-    #             ),
-    #             (
-    #                 rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.mlp\.fc(1|2)\.(weight|bias)",
-    #                 lambda match: f"layers.{match.group(1)}.mlp.w{match.group(2)}.{match.group(3)}",
-    #             ),
-    #             (
-    #                 rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.layer_norm1\.(weight|bias)",
-    #                 lambda match: f"layers.{match.group(1)}.sa_norm.{match.group(2)}",
-    #             ),
-    #             (
-    #                 rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.layer_norm2\.(weight|bias)",
-    #                 lambda match: f"layers.{match.group(1)}.mlp_norm.{match.group(2)}",
-    #             ),
-    #         ]
-
-    #         # Apply the patterns to update the name mapping
-    #         for pattern, replacement in patterns:
-    #             for key in list(hf_state_dict.keys()):
-    #                 if re.match(pattern, key):
-    #                     new_key = re.sub(pattern, replacement, key)
-    #                     name_mapping[key] = new_key
-
-    #         # Process the combined self-attention weights and biases
-    #         temp_state_dict = {}
-    #         for k, v in hf_state_dict.items():
-    #             new_k = name_mapping[k]
-    #             if "in_proj_weight" in new_k or "in_proj_bias" in new_k:
-    #                 if new_k not in temp_state_dict:
-    #                     temp_state_dict[new_k] = {"q": None, "k": None, "v": None}
-    #                 if "q_proj" in k:
-    #                     temp_state_dict[new_k]["q"] = v
-    #                 elif "k_proj" in k:
-    #                     temp_state_dict[new_k]["k"] = v
-    #                 elif "v_proj" in k:
-    #                     temp_state_dict[new_k]["v"] = v
-    #             else:
-    #                 temp_state_dict[new_k] = v
-
-    #         # Final processing of the combined self-attention weights and biases
-    #         for k, v in temp_state_dict.items():
-    #             if isinstance(v, dict):
-    #                 translated_state_dict[k] = torch.cat([v["q"], v["k"], v["v"]], dim=0)
-    #             else:
-    #                 translated_state_dict[k] = v
-
-    #         return translated_state_dict
-
-    #     new_state_dict = {}
-    #     for k, v in state_dict.items():
-    #         if k.startswith("model.model."):
-    #             new_state_dict[k.replace("model.model.", "")] = v
-    #         elif k.startswith("model."):
-    #             new_state_dict[k.replace("model.", "")] = v
-    #         else:
-    #             new_state_dict[k] = v
-    #     return new_state_dict
-
-    #     def _translate_state_dict_for_text_model(hf_state_dict) -> Dict[str, Any]:
-    #         key_map = {
-    #             # fmt: off
-    #             r"model.layers.([0-9]+).self_attn.q_proj.": r"layers.\1.attention.wq.",
-    #             r"model.layers.([0-9]+).self_attn.k_proj.": r"layers.\1.attention.wk.",
-    #             r"model.layers.([0-9]+).self_attn.v_proj.": r"layers.\1.attention.wv.",
-    #             r"model.layers.([0-9]+).self_attn.o_proj.": r"layers.\1.attention.wo.",
-    #             r"model.layers.([0-9]+).input_layernorm.": r"layers.\1.attention_norm.",
-    #             r"model.layers.([0-9]+).mlp.gate_proj.": r"layers.\1.feed_forward.w1.",
-    #             r"model.layers.([0-9]+).mlp.down_proj.": r"layers.\1.feed_forward.w2.",
-    #             r"model.layers.([0-9]+).mlp.up_proj.": r"layers.\1.feed_forward.w3.",
-    #             r"model.layers.([0-9]+).post_attention_layernorm.": r"layers.\1.ffn_norm.",
-    #             r"model.norm.": r"norm.",
-    #             # r"model.embed_tokens.": r"tok_embeddings.", # load separately
-    #             r"lm_head.": r"output.",
-    #             # fmt: on
-    #         }
-
-    #         new_state_dict = {}
-
-    #         def get_new_key(old_key: str) -> str:
-    #             for old_pattern, replacement in key_map.items():
-    #                 if (new_key := re.sub(old_pattern, replacement, old_key)) != old_key:
-    #                     return new_key
-
-    #             return old_key
-
-    #         # Convert module keys from hf transformer to Llama transformer.
-    #         for old_key in hf_state_dict.keys():
-    #             new_key = get_new_key(old_key)
-
-    #             new_state_dict[new_key] = hf_state_dict[old_key]
-
-    #         return new_state_dict
-
-    #     def split_checkpoint(llava_ckpt):
-    #         from collections import OrderedDict
-    #         language_model_ckpt = OrderedDict()
-    #         multi_modal_ckpt = OrderedDict()
-    #         vision_tower_ckpt = OrderedDict()
-    #         for key, value in llava_ckpt.items():
-    #             if key.startswith("language_model"):
-    #                 language_model_ckpt[key[len("language_model") + 1:]] = value
-    #             elif key.startswith("multi_modal_projector"):
-    #                 multi_modal_ckpt[key[len("multi_modal_projector") + 1:]] = value
-    #             elif key.startswith("vision_tower"):
-    #                 vision_tower_ckpt[key[len("vision_tower") + 1:]] = value
-    #         return language_model_ckpt, multi_modal_ckpt, vision_tower_ckpt
-
-    #     llava_model = llava_model.model
-
-    #     language_model_ckpt, multi_modal_ckpt, vision_tower_ckpt = split_checkpoint(llava_ckpt)
-
-    #     llava_model.tok_embeddings.load_state_dict({"weight": language_model_ckpt.pop("model.embed_tokens.weight")})
-
-    #     llava_model.encoder.load_state_dict(state_dict=_translate_state_dict_for_vision_model(vision_tower_ckpt),
-    #                                         strict=True,
-    #                                         assign=True,
-    #                                         )
-
-    #     llava_model.decoder.load_state_dict(state_dict=_translate_state_dict_for_text_model(language_model_ckpt),
-    #                                         strict=True,
-    #                                         assign=True,
-    #                                         )
-
-    #     llava_model.mm_projector.load_state_dict(state_dict=multi_modal_ckpt,
-    #                                              strict=True,
-    #                                              assign=True,
-    #                                              )
-
-    def remap_llava_checkpoint(llava_ckpt):
-        def _translate_state_dict_for_vision_model(hf_state_dict) -> Dict[str, Any]:
-            translated_state_dict = {}
-            hf_weight_prefix = "vision_model."
-            name_mapping = {
-                f"{hf_weight_prefix}embeddings.class_embedding": "model.encoder.cls_token_embedding.weight",
-                f"{hf_weight_prefix}embeddings.position_embedding.weight": "model.encoder.token_pos_embedding.positional_embedding",
-                f"{hf_weight_prefix}embeddings.patch_embedding.weight": "model.encoder.conv.weight",
-                f"{hf_weight_prefix}pre_layrnorm.weight": "model.encoder.ln_pre.weight",
-                f"{hf_weight_prefix}pre_layrnorm.bias": "model.encoder.ln_pre.bias",
-                f"{hf_weight_prefix}post_layernorm.weight": "model.encoder.ln_post.weight",
-                f"{hf_weight_prefix}post_layernorm.bias": "model.encoder.ln_post.bias",
-            }
-            patterns = [
-                (
-                    rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.self_attn\.(k|q|v)_proj\.(weight|bias)",
-                    lambda match: f"model.encoder.layers.{match.group(1)}.attn.{match.group(2)}_proj.{match.group(3)}",
-                ),
-                (
-                    rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.(weight|bias)",
-                    lambda match: f"model.encoder.layers.{match.group(1)}.attn.output_proj.{match.group(2)}",
-                ),
-                (
-                    rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.mlp\.fc(1|2)\.(weight|bias)",
-                    lambda match: f"model.encoder.layers.{match.group(1)}.mlp.w{match.group(2)}.{match.group(3)}",
-                ),
-                (
-                    rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.layer_norm1\.(weight|bias)",
-                    lambda match: f"model.encoder.layers.{match.group(1)}.sa_norm.{match.group(2)}",
-                ),
-                (
-                    rf"{hf_weight_prefix}encoder\.layers\.([0-9]+)\.layer_norm2\.(weight|bias)",
-                    lambda match: f"model.encoder.layers.{match.group(1)}.mlp_norm.{match.group(2)}",
-                ),
-            ]
-            for pattern, replacement in patterns:
-                for key in list(hf_state_dict.keys()):
-                    if re.match(pattern, key):
-                        new_key = re.sub(pattern, replacement, key)
-                        name_mapping[key] = new_key
-            temp_state_dict = {}
-            for k, v in hf_state_dict.items():
-                new_k = name_mapping.get(k, k)
-                if "in_proj_weight" in new_k or "in_proj_bias" in new_k:
-                    if new_k not in temp_state_dict:
-                        temp_state_dict[new_k] = {"q": None, "k": None, "v": None}
-                    if "q_proj" in k:
-                        temp_state_dict[new_k]["q"] = v
-                    elif "k_proj" in k:
-                        temp_state_dict[new_k]["k"] = v
-                    elif "v_proj" in k:
-                        temp_state_dict[new_k]["v"] = v
-                else:
-                    temp_state_dict[new_k] = v
-            for k, v in temp_state_dict.items():
-                if isinstance(v, dict):
-                    translated_state_dict[k] = torch.cat([v["q"], v["k"], v["v"]], dim=0)
-                else:
-                    translated_state_dict[k] = v
-            return translated_state_dict
-
-        def _translate_state_dict_for_text_model(hf_state_dict) -> Dict[str, Any]:
-            key_map = {
-                r"model.layers.([0-9]+).self_attn.q_proj.": r"model.decoder.layers.\1.attention.wq.",
-                r"model.layers.([0-9]+).self_attn.k_proj.": r"model.decoder.layers.\1.attention.wk.",
-                r"model.layers.([0-9]+).self_attn.v_proj.": r"model.decoder.layers.\1.attention.wv.",
-                r"model.layers.([0-9]+).self_attn.o_proj.": r"model.decoder.layers.\1.attention.wo.",
-                r"model.layers.([0-9]+).input_layernorm.": r"model.decoder.layers.\1.attention_norm.",
-                r"model.layers.([0-9]+).mlp.gate_proj.": r"model.decoder.layers.\1.feed_forward.w1.",
-                r"model.layers.([0-9]+).mlp.down_proj.": r"model.decoder.layers.\1.feed_forward.w2.",
-                r"model.layers.([0-9]+).mlp.up_proj.": r"model.decoder.layers.\1.feed_forward.w3.",
-                r"model.layers.([0-9]+).post_attention_layernorm.": r"model.decoder.layers.\1.ffn_norm.",
-                r"model.norm.": r"model.decoder.norm.",
-                # r"model.embed_tokens.": r"tok_embeddings.", # load separately
-                r"lm_head.": r"model.decoder.output.",
-            }
-            new_state_dict = {}
-            def get_new_key(old_key: str) -> str:
-                for old_pattern, replacement in key_map.items():
-                    if (new_key := re.sub(old_pattern, replacement, old_key)) != old_key:
-                        return new_key
-                return old_key
-            for old_key in hf_state_dict.keys():
-                new_key = get_new_key(old_key)
-                new_state_dict[new_key] = hf_state_dict[old_key]
-            return new_state_dict
-
-        def _translate_state_dict_for_mm_projector_model(hf_state_dict) -> Dict[str, Any]:
-            new_state_dict = {}
-            for old_key in hf_state_dict.keys():
-                new_key = "model.mm_projector." + old_key
-                new_state_dict[new_key] = hf_state_dict[old_key]
-            return new_state_dict
-
-        def split_checkpoint(llava_ckpt):
-            language_model_ckpt = {}
-            multi_modal_ckpt = {}
-            vision_tower_ckpt = {}
-            for key, value in llava_ckpt.items():
-                if key.startswith("language_model"):
-                    language_model_ckpt[key[len("language_model") + 1:]] = value
-                elif key.startswith("multi_modal_projector"):
-                    multi_modal_ckpt[key[len("multi_modal_projector") + 1:]] = value
-                elif key.startswith("vision_tower"):
-                    vision_tower_ckpt[key[len("vision_tower") + 1:]] = value
-            return language_model_ckpt, multi_modal_ckpt, vision_tower_ckpt
-        language_model_ckpt, multi_modal_ckpt, vision_tower_ckpt = split_checkpoint(llava_ckpt)
-        remapped_state_dict = {
-            "model.tok_embeddings.weight": language_model_ckpt.pop("model.embed_tokens.weight"),
-        }
-        remapped_state_dict.update(_translate_state_dict_for_text_model(language_model_ckpt))
-        remapped_state_dict.update(_translate_state_dict_for_vision_model(vision_tower_ckpt))
-        remapped_state_dict.update(_translate_state_dict_for_mm_projector_model(multi_modal_ckpt))
-        return remapped_state_dict
-
-    with torch.device("cuda"):
-        print("Preparing input")
-        pre_tokens = torch.tensor([[ 1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116,
-                                    21082, 20255, 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892,
-                                    322, 1248, 568, 6089, 304, 278, 5199, 29915, 29879, 5155,
-                                    29889, 3148, 1001, 29901, 29871]])
-        # img = prepare_image(336, 336)
-        post_tokens = torch.tensor([[29871, 13, 462, 9651, 1724, 526, 278, 2712, 306, 881,
-                                    367, 274, 1300, 2738, 1048, 746, 306, 6493, 1244, 29973,
-                                    319, 1799, 9047, 13566, 29901]])
-        img = llava_image_preprocess(target_h=336, target_w=336, image_mean=[0.48145466, 0.4578275, 0.40821073], image_std=[0.26862954, 0.26130258, 0.27577711], rescale_factor=0.00392156862745098)
-        print(img)
-
-        print("Done, Now creating model...")
-        llava_model = Model.from_params("/home/gasoonjia/torchchat/torchchat/model_params/llava-1.5.json")
-
-        llava_model = llava_model.eval()
-
-        print("Done. Now loading checkpoint...")
-        llava_ckpt = torch.load("/home/gasoonjia/executorch/examples/models/llava/llava_checkpoint.pth", map_location="cuda")
-
-        print("Done. Now checkpoint remapping...")
-        remapped_state_dict = remap_llava_checkpoint(llava_ckpt)
-        llava_model.load_state_dict(remapped_state_dict, strict=True)
-
-        print("Done. Now setup caches...")
-
-        llava_model.setup_caches(1, 768)
-
-        print("Done. Now running prefilling inference...")
-        # being tested, using llama_transformer
-        context_len, prefill_logits = llava_model(tokens=pre_tokens, encoder_input=img, post_tokens=post_tokens)
-        print("prefill_logits: ")
-        print(prefill_logits[0, -1].shape)
-        print(prefill_logits[0, -1])
-        print("context_len: \n", context_len)
-        # Always generate one token at a time.
-        new_tokens = [torch.argmax(prefill_logits[0, -1], dim=-1).item()]
-        print(new_tokens)
-        print(prefill_logits.shape)
-        print("Done. Now running generation inference...")
-        for i in range(10):
-            logits = llava_model(
-                torch.tensor([new_tokens[i]]), input_pos=torch.tensor([context_len + i])
-            )
-            print(f"{i}-th logits: ")
-            print(logits)
-
-            print(f"{i}-th logits.shape: ")
-            print(logits.shape)
-            new_tokens.append(torch.argmax(logits[-1, :]).item())
-
-        print("Done. The output is:", new_tokens)