Commit
Implement lora support (theroyallab#24)
* Model: Implement basic lora support

* Add ability to load loras from config on launch
* Supports loading multiple loras and lora scaling
* Add function to unload loras

* Colab: Update for basic lora support

* Model: Test vram alloc after lora load, add docs

* Git: Add loras folder to .gitignore

* API: Add basic lora-related endpoints

* Add /loras/ endpoint for querying available loras
* Add /model/lora endpoint for querying currently loaded loras
* Add /model/lora/load endpoint for loading loras
* Add /model/lora/unload endpoint for unloading loras
* Move lora config-checking logic to main.py for better compat with API endpoints

* Revert bad CRLF line ending changes

* API: Add basic lora-related endpoints (fixed)

* Add /loras/ endpoint for querying available loras
* Add /model/lora endpoint for querying currently loaded loras
* Add /model/lora/load endpoint for loading loras
* Add /model/lora/unload endpoint for unloading loras
* Move lora config-checking logic to main.py for better compat with API endpoints

* Model: Unload loras first when unloading model

* API + Models: Cleanup lora endpoints and functions

Condenses the endpoint and model load code. Also makes the lora routes
behave the same way as the model routes so they don't confuse the end user.

Signed-off-by: kingbri <bdashore3@proton.me>

* Loras: Optimize load endpoint

Return successes and failures, and consolidate the request handling
into the rewritten load_loras function.

Signed-off-by: kingbri <bdashore3@proton.me>

---------

Co-authored-by: kingbri <bdashore3@proton.me>
Co-authored-by: DocShotgun <126566557+DocShotgun@users.noreply.github.com>
DocShotgun and bdashore3 committed Dec 9, 2023
1 parent 161c9d2 commit 7380a3b
Showing 8 changed files with 197 additions and 19 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -182,3 +182,7 @@ api_tokens.yml
# Models folder
models/*
!models/place_your_models_here.txt

# Loras folder
loras/*
!loras/place_your_loras_here.txt
25 changes: 25 additions & 0 deletions OAI/types/lora.py
@@ -0,0 +1,25 @@
from pydantic import BaseModel, Field
from time import time
from typing import Optional, List

class LoraCard(BaseModel):
    id: str = "test"
    object: str = "lora"
    created: int = Field(default_factory=lambda: int(time()))
    owned_by: str = "tabbyAPI"
    scaling: Optional[float] = None

class LoraList(BaseModel):
    object: str = "list"
    data: List[LoraCard] = Field(default_factory=list)

class LoraLoadInfo(BaseModel):
    name: str
    scaling: Optional[float] = 1.0

class LoraLoadRequest(BaseModel):
    loras: List[LoraLoadInfo]

class LoraLoadResponse(BaseModel):
    success: List[str] = Field(default_factory=list)
    failure: List[str] = Field(default_factory=list)
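
As a quick illustration of the schemas above, a minimal sketch of building a load request and reading a load response (the lora names are placeholders, and the repository root is assumed to be on the import path):

from OAI.types.lora import LoraLoadInfo, LoraLoadRequest, LoraLoadResponse

# Request body for /v1/lora/load: a list of loras with optional per-lora scaling
request = LoraLoadRequest(loras=[
    LoraLoadInfo(name="my-style-lora"),                  # scaling defaults to 1.0
    LoraLoadInfo(name="my-instruct-lora", scaling=0.5),
])

# Response body reporting which loras loaded and which failed
response = LoraLoadResponse(success=["my-style-lora"], failure=["my-instruct-lora"])

print(request.dict())
print(response.dict())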
12 changes: 11 additions & 1 deletion OAI/utils.py
@@ -8,9 +8,10 @@
ChatCompletionStreamChoice
)
from OAI.types.common import UsageStats
from OAI.types.lora import LoraList, LoraCard
from OAI.types.model import ModelList, ModelCard
from packaging import version
from typing import Optional, List
from typing import Optional, List, Dict

# Check fastchat
try:
@@ -100,6 +101,15 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str]):

return model_card_list

def get_lora_list(lora_path: pathlib.Path):
    lora_list = LoraList()
    for path in lora_path.iterdir():
        if path.is_dir():
            lora_card = LoraCard(id = path.name)
            lora_list.data.append(lora_card)

    return lora_list

def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage]):

# Check if fastchat is available
34 changes: 29 additions & 5 deletions TabbyAPI_Colab_Example.ipynb
@@ -36,10 +36,17 @@
"# @markdown Select model:\n",
"repo_id = \"royallab/Noromaid-13b-v0.1.1-exl2\" # @param {type:\"string\"}\n",
"revision = \"4bpw\" # @param {type:\"string\"}\n",
"if revision == \"\": revision = \"main\"\n",
"# @markdown ---\n",
"# @markdown Select draft model (optional, for speculative decoding):\n",
"draft_repo_id = \"\" # @param {type:\"string\"}\n",
"draft_revision = \"\" # @param {type:\"string\"}\n",
"if draft_revision == \"\": draft_revision = \"main\"\n",
"# @markdown ---\n",
"# @markdown Select lora (optional):\n",
"lora_repo_id = \"\" # @param {type:\"string\"}\n",
"lora_revision = \"\" # @param {type:\"string\"}\n",
"if lora_revision == \"\": lora_revision = \"main\"\n",
"# @markdown ---\n",
"\n",
"# Install tabbyAPI\n",
@@ -62,8 +69,15 @@
"%cd /content/tabbyAPI/\n",
"\n",
"from huggingface_hub import snapshot_download\n",
"\n",
"snapshot_download(repo_id=repo_id, revision=revision, local_dir=f\"./models/{repo_id.replace('/', '_')}\")\n",
"if len(draft_repo_id) > 0: snapshot_download(repo_id=draft_repo_id, revision=draft_revision, local_dir=f\"./models/{draft_repo_id.replace('/', '_')}\")"
"model = repo_id.replace('/', '_')\n",
"\n",
"if len(draft_repo_id) > 0: snapshot_download(repo_id=draft_repo_id, revision=draft_revision, local_dir=f\"./models/{draft_repo_id.replace('/', '_')}\")\n",
"draft_model = draft_repo_id.replace('/', '_')\n",
"\n",
"if len(lora_repo_id) > 0: snapshot_download(repo_id=lora_repo_id, revision=lora_revision, local_dir=f\"./loras/{lora_repo_id.replace('/', '_')}\")\n",
"lora = lora_repo_id.replace('/', '_')"
]
},
{
@@ -77,9 +91,6 @@
"# @title # Configure and launch API { display-mode: \"form\" }\n",
"# @markdown ---\n",
"# @markdown Model parameters:\n",
"\n",
"model = repo_id.replace('/', '_')\n",
"draft_model = draft_repo_id.replace('/', '_')\n",
"ContextSize = 4096 # @param {type:\"integer\"}\n",
"RopeScale = 1.0 # @param {type:\"number\"}\n",
"RopeAlpha = 1.0 # @param {type:\"number\"}\n",
@@ -88,6 +99,9 @@
"DraftRopeScale = 1.0 # @param {type:\"number\"}\n",
"DraftRopeAlpha = 1.0 # @param {type:\"number\"}\n",
"# @markdown ---\n",
"# @markdown Lora parameters (optional, for loras):\n",
"LoraScaling = 1.0 # @param {type:\"number\"}\n",
"# @markdown ---\n",
"# @markdown Misc options:\n",
"CacheMode = \"FP16\" # @param [\"FP8\", \"FP16\"] {type:\"string\"}\n",
"UseDummyModels = False # @param {type:\"boolean\"}\n",
@@ -161,6 +175,16 @@
" # Rope parameters for draft models (default: 1.0)\n",
" draft_rope_scale: {DraftRopeScale}\n",
" draft_rope_alpha: {DraftRopeAlpha}\n",
"\n",
" # Options for loras\n",
" lora:\n",
" # Overrides the directory to look for loras (default: loras)\n",
" lora_dir: loras\n",
"\n",
" # List of loras to load and associated scaling factors (default: 1.0). Comment out unused entries or add more rows as needed.\n",
" loras:\n",
" - name: {lora}\n",
" scaling: {LoraScaling}\n",
"'''\n",
"with open(\"./config.yml\", \"w\") as file:\n",
" file.write(write)\n",
@@ -188,4 +212,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}
14 changes: 14 additions & 0 deletions config_sample.yml
@@ -60,3 +60,17 @@ model:
  # Rope parameters for draft models (default: 1.0)
  draft_rope_scale: 1.0
  draft_rope_alpha: 1.0

  # Options for loras
  lora:
    # Overrides the directory to look for loras (default: loras)
    lora_dir: Your lora directory path

    # List of loras to load and associated scaling factors (default: 1.0). Comment out unused entries or add more rows as needed.
    loras:
    - name: lora1
      scaling: 1.0
    - name: lora2
      scaling: 0.9
    - name: lora3
      scaling: 0.5
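
For reference, a rough sketch of how this section is consumed at launch: parsed with PyYAML it becomes a plain dict nested under the model key, which main.py then hands to the lora loader (the config file name and lora entries here are placeholders):

import pathlib
import yaml

# Parse the same config file the server reads at startup
with open("config.yml") as f:
    config = yaml.safe_load(f)

model_config = config.get("model") or {}
lora_config = model_config.get("lora") or {}
# e.g. {"lora_dir": "loras", "loras": [{"name": "lora1", "scaling": 1.0}, ...]}

if "loras" in lora_config:
    lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
    print(f"Would load {len(lora_config['loras'])} lora(s) from {lora_dir.resolve()}")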
Empty file added loras/place_your_loras_here.txt
67 changes: 65 additions & 2 deletions main.py
@@ -11,6 +11,7 @@
from generators import generate_with_semaphore
from OAI.types.completion import CompletionRequest
from OAI.types.chat_completion import ChatCompletionRequest
from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
from OAI.types.token import (
TokenEncodeRequest,
@@ -21,6 +22,7 @@
from OAI.utils import (
create_completion_response,
get_model_list,
get_lora_list,
get_chat_completion_prompt,
create_chat_completion_response,
create_chat_completion_stream_chunk
@@ -87,7 +89,6 @@ async def load_model(request: Request, data: ModelLoadRequest):
if not data.name:
raise HTTPException(400, "model_name not found.")

# TODO: Move this to model_container
model_config = config.get("model") or {}
model_path = pathlib.Path(model_config.get("model_dir") or "models")
model_path = model_path / data.name
@@ -160,7 +161,63 @@ async def unload_model():
global model_container

model_container.unload()
model_container = None
model_container = None

# Lora list endpoint
@app.get("/v1/loras", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def get_all_loras():
    model_config = config.get("model") or {}
    lora_config = model_config.get("lora") or {}
    lora_path = pathlib.Path(lora_config.get("lora_dir") or "loras")

    loras = get_lora_list(lora_path.resolve())

    return loras

# Currently loaded loras endpoint
@app.get("/v1/lora", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def get_active_loras():
    active_loras = LoraList(
        data = list(map(
            lambda lora: LoraCard(
                id = pathlib.Path(lora.lora_path).parent.name,
                scaling = lora.lora_scaling * lora.lora_r / lora.lora_alpha
            ),
            model_container.active_loras
        ))
    )

    return active_loras

# Load lora endpoint
@app.post("/v1/lora/load", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
async def load_lora(data: LoraLoadRequest):
    if not data.loras:
        raise HTTPException(400, "List of loras to load is not found.")

    model_config = config.get("model") or {}
    lora_config = model_config.get("lora") or {}
    lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
    if not lora_dir.exists():
        raise HTTPException(400, "A parent lora directory does not exist. Check your config.yml?")

    # Clean up existing loras if present
    if len(model_container.active_loras) > 0:
        model_container.unload(True)

    result = model_container.load_loras(lora_dir, **data.dict())
    return LoraLoadResponse(
        success = result.get("success") or [],
        failure = result.get("failure") or []
    )

# Unload lora endpoint
@app.get("/v1/lora/unload", dependencies=[Depends(check_admin_key), Depends(_check_model_container)])
async def unload_loras():
    global model_container

    model_container.unload(True)

# Encode tokens endpoint
@app.post("/v1/token/encode", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
@@ -308,6 +365,12 @@ async def generator():
else:
loading_bar.next()

# Load loras
lora_config = model_config.get("lora") or {}
if "loras" in lora_config:
    lora_dir = pathlib.Path(lora_config.get("lora_dir") or "loras")
    model_container.load_loras(lora_dir.resolve(), **lora_config)

network_config = config.get("network") or {}
uvicorn.run(
app,
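
With the server running, the new routes can be exercised over HTTP. A minimal sketch using the requests library; the base URL, the key header names, and the lora name are illustrative assumptions, not something this commit defines:

import requests

BASE = "http://localhost:5000"            # assumed default host/port
HEADERS = {"x-api-key": "<api key>"}      # assumed header name for the API key
ADMIN_HEADERS = {"x-admin-key": "<admin key>"}  # assumed header name for admin routes

# List loras available in the lora directory
print(requests.get(f"{BASE}/v1/loras", headers=HEADERS).json())

# Load a lora (name is a placeholder) at half scaling, then check what is active
requests.post(f"{BASE}/v1/lora/load", headers=ADMIN_HEADERS,
              json={"loras": [{"name": "my-lora", "scaling": 0.5}]})
print(requests.get(f"{BASE}/v1/lora", headers=HEADERS).json())

# Unload all active loras
requests.get(f"{BASE}/v1/lora/unload", headers=ADMIN_HEADERS)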
60 changes: 49 additions & 11 deletions model.py
@@ -6,6 +6,7 @@
ExLlamaV2Cache,
ExLlamaV2Cache_8bit,
ExLlamaV2Tokenizer,
ExLlamaV2Lora
)
from exllamav2.generator import(
ExLlamaV2StreamingGenerator,
@@ -30,6 +31,8 @@ class ModelContainer:
    cache_fp8: bool = False
    gpu_split_auto: bool = True
    gpu_split: list or None = None

    active_loras: List[ExLlamaV2Lora] = []

def __init__(self, model_directory: pathlib.Path, quiet = False, **kwargs):
"""
@@ -54,6 +57,8 @@ def progress(loaded_modules: int, total_modules: int, loading_draft: bool)
'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft model.
By default, the draft model's alpha value is calculated automatically to scale to the size of the
full model.
'lora_dir' (str): Lora directory
'loras' (list[dict]): List of loras to be loaded, consisting of 'name' and 'scaling'
'gpu_split_auto' (bool): Automatically split model across available devices (default: True)
'gpu_split' (list[float]): Allocation for weights and (some) tensors, per device
'no_flash_attn' (bool): Turns off flash attention (increases vram usage) (default: False)
@@ -141,6 +146,32 @@ def progress(loaded_modules: int, total_modules: int)
"""
for _ in self.load_gen(progress_callback): pass

    def load_loras(self, lora_directory: pathlib.Path, **kwargs):
        """
        Load loras
        """

        loras = kwargs.get("loras") or []
        success: List[str] = []
        failure: List[str] = []

        for lora in loras:
            lora_name = lora.get("name") or None
            lora_scaling = lora.get("scaling") or 1.0

            if lora_name is None:
                print("One of your loras does not have a name. Please check your config.yml! Skipping lora load.")
                failure.append(lora_name)
                continue

            print(f"Loading lora: {lora_name} at scaling {lora_scaling}")
            lora_path = lora_directory / lora_name
            self.active_loras.append(ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling))
            print("Lora successfully loaded.")
            success.append(lora_name)

        # Return success and failure names
        return { 'success': success, 'failure': failure }

def load_gen(self, progress_callback = None):
"""
@@ -204,23 +235,30 @@ def progress(loaded_modules: int, total_modules: int)
print("Model successfully loaded.")


    def unload(self):
    def unload(self, loras_only: bool = False):
        """
        Free all VRAM resources used by this model
        """

        if self.model: self.model.unload()
        self.model = None
        if self.draft_model: self.draft_model.unload()
        self.draft_model = None
        self.config = None
        self.cache = None
        self.tokenizer = None
        self.generator = None
        for lora in self.active_loras:
            lora.unload()

        self.active_loras = []

        # Unload the entire model if not just unloading loras
        if not loras_only:
            if self.model: self.model.unload()
            self.model = None
            if self.draft_model: self.draft_model.unload()
            self.draft_model = None
            self.config = None
            self.cache = None
            self.tokenizer = None
            self.generator = None

        gc.collect()
        torch.cuda.empty_cache()


# Common function for token operations
def get_tokens(self, text: Optional[str], ids: Optional[List[int]], **kwargs):
if text:
@@ -381,7 +419,7 @@ def generate_gen(self, prompt: str, **kwargs):
active_ids = ids[:, max(0, overflow):]
chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]

self.generator.begin_stream(active_ids, gen_settings, token_healing = token_healing)
self.generator.begin_stream(active_ids, gen_settings, token_healing = token_healing, loras = self.active_loras)

# Generate

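
Putting the model.py changes together: load_loras fills active_loras, and generation passes that list to begin_stream. A rough sketch of driving a container directly under those assumptions (the model directory and lora names are placeholders):

import pathlib
from model import ModelContainer

# Paths and lora names below are placeholders
container = ModelContainer(pathlib.Path("models/my-model"))
container.load()  # load the base model first; loras attach to it

result = container.load_loras(
    pathlib.Path("loras").resolve(),
    loras=[
        {"name": "my-style-lora", "scaling": 0.8},
        {"name": "my-instruct-lora"},  # scaling falls back to 1.0
    ],
)
print(result)  # {'success': [...], 'failure': [...]}

# Unload just the loras, keeping the base model resident
container.unload(loras_only=True)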
