Update scripts and docs
haotian-liu committed Jun 11, 2023
1 parent 700863b commit 9221723
Showing 4 changed files with 97 additions and 17 deletions.
27 changes: 27 additions & 0 deletions docs/LoRA.md
@@ -0,0 +1,27 @@
# LLaVA (LoRA, Technical Preview)

NOTE: This is a technical preview, and is not yet ready for production use.

## Demo (Web UI)

Please execute each of the commands below one at a time, waiting for the previous one to finish. The commands are the same as those for launching other demos, except for the additional `--model-base` flag, which specifies the base model to use. Please make sure the base model corresponds to the LoRA checkpoint that you are using. For this technical preview, you need the Vicuna v1.1 (7B) checkpoint (if you do not have it already, follow the instructions [here](https://github.com/lm-sys/FastChat#vicuna-weights)).

#### Launch a controller
```Shell
python -m llava.serve.controller --host 0.0.0.0 --port 10000
```

#### Launch a model worker
```Shell
python -m llava.serve.model_worker --host 0.0.0.0 --controller http://localhost:10000 --port 40000 --worker http://localhost:40000 --model-path liuhaotian/llava-vicuna-7b-v1.1-lcs_558k-instruct_80k_3e-lora-preview-alpha --model-base /path/to/vicuna-v1.1
```
Wait until the process finishes loading the model and you see "Uvicorn running on ...".

#### Launch a Gradio web server
```Shell
python -m llava.serve.gradio_web_server --controller http://localhost:10000
```

## Training

Please see training scripts at [`./scripts/deepspeed`](./scripts/deepspeed).
73 changes: 57 additions & 16 deletions llava/eval/model_vqa.py
@@ -52,10 +52,46 @@ def eval_model(args):
    # Model
    disable_torch_init()
    model_name = os.path.expanduser(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if 'lora' in model_name.lower():
        lora_cfg_pretrained = AutoConfig.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(args.base_model_path)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    if args.mm_projector is None:
        patch_config(model_name)
        model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda()
        if 'lora' in model_name.lower():
            print('Loading LLaVA from base model...')
            llama_state_dict = AutoModelForCausalLM.from_pretrained(args.base_model_path, torch_dtype=torch.float16).state_dict()
            model = LlavaLlamaForCausalLM.from_pretrained(args.base_model_path, config=lora_cfg_pretrained, state_dict=llama_state_dict, torch_dtype=torch.float16, ignore_mismatched_sizes=True)

            print('Loading LLaVA trainable weights...')
            if os.path.exists(os.path.join(model_name, 'non_lora_trainables.bin')):
                non_lora_trainables = torch.load(os.path.join(model_name, 'non_lora_trainables.bin'), map_location='cpu')
            else:
                # this is probably from HF Hub
                from huggingface_hub import hf_hub_download
                def load_from_hf(repo_id, filename, subfolder=None):
                    cache_file = hf_hub_download(
                        repo_id=repo_id,
                        filename=filename,
                        subfolder=subfolder)
                    return torch.load(cache_file, map_location='cpu')
                non_lora_trainables = load_from_hf(model_name, 'non_lora_trainables.bin')
            non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
            if any(k.startswith('model.model.embed_tokens') for k in non_lora_trainables):
                non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
            non_lora_trainables = {k: v.to(torch.float16) for k, v in non_lora_trainables.items()}
            model.load_state_dict(non_lora_trainables, strict=False)

            from peft import PeftModel
            print('Loading LoRA weights...')
            model = PeftModel.from_pretrained(model, model_name)
            print('Merging LoRA weights...')
            model = model.merge_and_unload()
            print('Moving to CUDA...')
            model = model.cuda()
        else:
            model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).cuda()
        image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)

        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
@@ -73,7 +109,7 @@ def eval_model(args):
        image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
    else:
        # in case of using a pretrained model with only a MLP projector weights
        model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda()
        model = LlavaLlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, use_cache=True).cuda()

        vision_tower = CLIPVisionModel.from_pretrained(args.vision_tower, torch_dtype=torch.float16).cuda()
        image_processor = CLIPImageProcessor.from_pretrained(args.vision_tower, torch_dtype=torch.float16)
@@ -113,16 +149,14 @@ def eval_model(args):
        else:
            qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len

        if args.conv_mode == 'simple_legacy':
            qs += '\n\n### Response:'
        # conv = default_conversation.copy()
        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        if args.conv_mode != 'simple':
            conv.append_message(conv.roles[1], "")
        prompt = conv.get_prompt()
        inputs = tokenizer([prompt])

        image = Image.open(os.path.join(args.image_folder, image_file))
        # image.save(os.path.join(save_image_folder, image_file))
        image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

        input_ids = torch.as_tensor(inputs.input_ids).cuda()
@@ -145,7 +179,10 @@ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kw
                            return True
                return False

        keywords = ['###']
        if args.conv_mode == 'simple':
            keywords = ['###']
        else:
            keywords = [conv.sep2]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        with torch.inference_mode():
@@ -155,15 +192,16 @@ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kw
                do_sample=True,
                temperature=0.7,
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] Sample {i}: {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0].strip()

        if args.conv_mode == 'simple_legacy' or args.conv_mode == 'simple':
        if args.conv_mode == 'simple':
            while True:
                cur_len = len(outputs)
                outputs = outputs.strip()
@@ -173,13 +211,15 @@ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kw
                if len(outputs) == cur_len:
                    break

        try:
            index = outputs.index(conv.sep)
        except ValueError:
            outputs += conv.sep
            index = outputs.index(conv.sep)
            try:
                index = outputs.index(conv.sep)
            except ValueError:
                outputs += conv.sep
                index = outputs.index(conv.sep)

        outputs = outputs[:index].strip()
            outputs = outputs[:index].strip()
        else:
            outputs = outputs.strip()

        ans_id = shortuuid.uuid()
        ans_file.write(json.dumps({"question_id": idx,
@@ -194,6 +234,7 @@ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kw
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--base-model-path", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default="")
    parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
    parser.add_argument("--answers-file", type=str, default="answer.jsonl")
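For reference, a hypothetical invocation of the updated evaluation script with the LoRA preview checkpoint might look like the sketch below. The flag names and the question-file default come from the argparse definitions in this diff; the image folder and answers file are placeholders, the module-style launch is an assumption, and any conversation-mode flag is omitted here since its value for this checkpoint is not specified in the diff.

```Shell
python -m llava.eval.model_vqa \
    --model-name liuhaotian/llava-vicuna-7b-v1.1-lcs_558k-instruct_80k_3e-lora-preview-alpha \
    --base-model-path /path/to/vicuna-v1.1 \
    --question-file tables/question.jsonl \
    --image-folder /path/to/eval-images \
    --answers-file answer.jsonl
```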
13 changes: 12 additions & 1 deletion llava/serve/model_worker.py
@@ -66,7 +66,18 @@ def load_model(model_path, model_base, model_name, num_gpus):
                model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim))

            print('Loading LLaVA trainable weights...')
            non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
            if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
                non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
            else:
                # this is probably from HF Hub
                from huggingface_hub import hf_hub_download
                def load_from_hf(repo_id, filename, subfolder=None):
                    cache_file = hf_hub_download(
                        repo_id=repo_id,
                        filename=filename,
                        subfolder=subfolder)
                    return torch.load(cache_file, map_location='cpu')
                non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
            non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
            if any(k.startswith('model.model.embed_tokens') for k in non_lora_trainables):
                non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
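The two key-renaming passes above (now shared by the eval script and the model worker) are easy to misread, so here is a toy sketch with made-up keys showing their intended effect: PEFT saves the non-LoRA parameters under a `base_model.` prefix, and an extra `model.` wrapper may also need to be stripped before the keys line up with `LlavaLlamaForCausalLM`'s state dict.

```python
# Toy illustration (hypothetical keys, dummy values) of the prefix normalization
# applied to non_lora_trainables before load_state_dict.
ckpt = {
    "base_model.model.model.embed_tokens.weight": 0,
    "base_model.model.model.mm_projector.weight": 1,
}
# Strip the "base_model." prefix that PEFT prepends to every saved parameter.
ckpt = {(k[11:] if k.startswith("base_model.") else k): v for k, v in ckpt.items()}
# If an extra "model." wrapper remains (detected via the embed_tokens key), strip it too.
if any(k.startswith("model.model.embed_tokens") for k in ckpt):
    ckpt = {(k[6:] if k.startswith("model.") else k): v for k, v in ckpt.items()}
print(sorted(ckpt))  # ['model.embed_tokens.weight', 'model.mm_projector.weight']
```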
1 change: 1 addition & 0 deletions pyproject.toml
@@ -16,6 +16,7 @@ dependencies = [
    "accelerate", "einops", "fastapi", "gradio==3.23", "markdown2[all]", "numpy",
    "requests", "sentencepiece", "tokenizers==0.12.1",
    "torch", "torchvision", "uvicorn", "wandb",
    "shortuuid",
    "deepspeed==0.9.2", "peft==0.3.0",
    "transformers @ git+https://github.com/huggingface/transformers.git@cae78c46"
]
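The new `shortuuid` dependency backs the `ans_id = shortuuid.uuid()` call in `model_vqa.py` above. A minimal sketch of what that call produces (the printed value is illustrative):

```python
import shortuuid

# shortuuid.uuid() returns a compact, URL-safe random identifier
# (22 characters with the default base-57 alphabet), used to tag each answer.
ans_id = shortuuid.uuid()
print(ans_id, len(ans_id))  # e.g. vytxeTZskVKR7C7WgdSP3d 22
```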
