Allow Peft models to share their base model #1905

Merged · 4 commits · Jul 10, 2023
6 changes: 5 additions & 1 deletion docs/model_support.md
@@ -30,7 +30,11 @@
- [baichuan-inc/baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B)
- [internlm/internlm-chat-7b](https://huggingface.co/internlm/internlm-chat-7b)
- Any [EleutherAI](https://huggingface.co/EleutherAI) pythia model such as [pythia-6.9b](https://huggingface.co/EleutherAI/pythia-6.9b)
- Any [Peft](https://github.com/huggingface/peft) adapter trained ontop of a model above. To activate, must have `peft` in the model path.
- Any [Peft](https://github.com/huggingface/peft) adapter trained on top of a
  model above. To activate, the model path must contain `peft`. Note: if you
  load multiple Peft models, you can have them share the base model weights by
  setting the environment variable `PEFT_SHARE_BASE_WEIGHTS=true` in any model
  worker (see the sketch below).
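
A minimal sketch of how the flag could be enabled from Python rather than the shell. It assumes only what the diff below shows: the flag is read once, at import time of `fastchat.model.model_adapter`, into the module-level `peft_share_base_weights`:

```python
import os

# Enable base-weight sharing before fastchat.model.model_adapter is imported,
# because the module reads PEFT_SHARE_BASE_WEIGHTS only once, at import time.
os.environ["PEFT_SHARE_BASE_WEIGHTS"] = "true"

from fastchat.model import model_adapter  # noqa: E402

# The module-level flag should now reflect the environment variable.
assert model_adapter.peft_share_base_weights
```

In a typical deployment you would instead export the variable in the shell before launching each model worker process.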

## How to support a new model

71 changes: 69 additions & 2 deletions fastchat/model/model_adapter.py
@@ -2,7 +2,7 @@

import math
import sys
from typing import List, Optional
from typing import Dict, List, Optional
import warnings

if sys.version_info >= (3, 9):
@@ -11,6 +11,7 @@
from functools import lru_cache as cache

import accelerate
import os
import psutil
import torch
from transformers import (
@@ -35,6 +36,12 @@
)
from fastchat.utils import get_gpu_memory

# Check an environment variable to see whether we should share Peft base model
# weights. When false, we treat each Peft model as a fully separate model.
peft_share_base_weights = (
os.environ.get("PEFT_SHARE_BASE_WEIGHTS", "false").lower() == "true"
)


class BaseModelAdapter:
"""The base and the default model adapter."""
@@ -254,6 +261,33 @@ def get_generate_stream_function(model: torch.nn.Module, model_path: str):
return generate_stream_falcon
elif is_codet5p:
return generate_stream_codet5p
elif peft_share_base_weights and "peft" in model_path:
# Return a curried stream function that activates the right adapter
# (registered under this model_path) before generating. This ensures
# the correct adapter weights are used for each request.
@torch.inference_mode()
def generate_stream_peft(
model,
tokenizer,
params: Dict,
device: str,
context_len: int,
stream_interval: int = 2,
judge_sent_end: bool = False,
):
model.set_adapter(model_path)
for x in generate_stream(
model,
tokenizer,
params,
device,
context_len,
stream_interval,
judge_sent_end,
):
yield x

return generate_stream_peft
else:
return generate_stream
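
As a side note on the design: `set_adapter` mutates the shared `PeftModel`, so the per-request adapter switch above is only safe if generate calls do not interleave. A toy sketch of that constraint, assuming requests are gated by an asyncio semaphore as the comments in this PR describe; `run_request` and its `generate` argument are hypothetical stand-ins, not part of the diff:

```python
import asyncio

# One permit: requests for different adapters are fully serialized, so the
# set_adapter() call cannot race with another request's generation.
semaphore = asyncio.Semaphore(1)

async def run_request(model, adapter_path: str, generate):
    """Serialize adapter activation and generation on the shared PeftModel.

    `generate` stands in for the actual streaming call; only the locking
    pattern matters here.
    """
    async with semaphore:
        # Activate this request's adapter, then generate while still
        # holding the permit.
        model.set_adapter(adapter_path)
        return generate(model)
```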

@@ -331,6 +365,9 @@ def remove_parent_directory_name(model_path):
return model_path.split("/")[-1]


peft_model_cache = {}


class PeftModelAdapter:
"""Loads any "peft" model and it's base model."""

@@ -349,12 +386,42 @@ def load_model(self, model_path: str, from_pretrained_kwargs: dict):
f"PeftModelAdapter cannot load a base model with 'peft' in the name: {config.base_model_name_or_path}"
)

# Basic proof of concept for loading peft adapters that share the base
# weights. This is pretty messy because Peft rewrites the underlying
# base model and internally stores a map of adapter layers.
# So, to make this work we:
# 1. Cache the first peft model loaded for a given base model.
# 2. Call `load_adapter` for any follow-on Peft models.
# 3. Make sure we register the adapters under their model_path. Why? This is
#    what's accessible at inference time.
# 4. In get_generate_stream_function, make sure we activate the right
#    adapter before doing inference. This *should* be safe as long as calls
#    are blocked by the same semaphore.
if peft_share_base_weights:
if base_model_path in peft_model_cache:
model, tokenizer = peft_model_cache[base_model_path]
# Super important: make sure we use model_path as the
# `adapter_name`.
model.load_adapter(model_path, adapter_name=model_path)
else:
base_adapter = get_model_adapter(base_model_path)
base_model, tokenizer = base_adapter.load_model(
base_model_path, from_pretrained_kwargs
)
# Super important: make sure we use model_path as the
# `adapter_name`.
model = PeftModel.from_pretrained(
base_model, model_path, adapter_name=model_path
)
peft_model_cache[base_model_path] = (model, tokenizer)
return model, tokenizer

# In the normal case, load up the base model weights again.
base_adapter = get_model_adapter(base_model_path)
base_model, tokenizer = base_adapter.load_model(
base_model_path, from_pretrained_kwargs
)
model = PeftModel.from_pretrained(base_model, model_path)

return model, tokenizer
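
For reference, a standalone sketch of the sharing pattern this adapter implements, written directly against the peft API; the adapter repo names are made up for illustration, while the base model is one from the supported list above:

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the base model exactly once.
base = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-6.9b")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-6.9b")

# First adapter: wrap the base model and register the adapter under its path,
# mirroring PeftModel.from_pretrained(..., adapter_name=model_path) above.
model = PeftModel.from_pretrained(
    base, "my-org/pythia-peft-a", adapter_name="my-org/pythia-peft-a"
)

# Second adapter: reuse the same PeftModel (and therefore the same base
# weights), mirroring the cache-hit branch that calls model.load_adapter(...).
model.load_adapter("my-org/pythia-peft-b", adapter_name="my-org/pythia-peft-b")

# Before serving a request, activate that request's adapter by its path.
model.set_adapter("my-org/pythia-peft-a")
```

Only one copy of the base weights is held in memory; each additional adapter contributes just its own adapter layers.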

def get_default_conv_template(self, model_path: str) -> Conversation: