Exllama cache 8bit (#2719)
mjkaye authored Nov 23, 2023
1 parent ff25295 commit 6ac7d76
Showing 6 changed files with 15 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/exllama_v2.md
@@ -41,6 +41,8 @@ python3 -m fastchat.serve.model_worker \
     --exllama-gpu-split 18,24
 ```
 
+`--exllama-cache-8bit` can be used to enable 8-bit caching with exllama and save some VRAM.
+
 ## Performance
 
 Reference: https://github.com/turboderp/exllamav2#performance
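For context, the flag is parsed in `add_model_args` and threaded through `ExllamaConfig.cache_8bit` into `load_exllama_model` (see the diffs below). Here is a minimal sketch of driving the same path programmatically, assuming `ExllamaConfig` accepts keyword construction as the diff suggests; the model path and the chosen values are placeholders, not taken from this commit:

```python
from fastchat.modules.exllama import ExllamaConfig, load_exllama_model

config = ExllamaConfig(
    max_seq_len=4096,    # context length passed through to exllamav2
    gpu_split="18,24",   # optional per-GPU VRAM split in GB, as in the docs example
    cache_8bit=True,     # what --exllama-cache-8bit sets
)
# "models/my-gptq-model" is a hypothetical path for illustration only.
model, tokenizer = load_exllama_model("models/my-gptq-model", config)
```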
5 changes: 5 additions & 0 deletions fastchat/model/model_adapter.py
@@ -501,6 +501,11 @@ def add_model_args(parser):
         default=None,
         help="Used for exllamabv2. Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7",
     )
+    parser.add_argument(
+        "--exllama-cache-8bit",
+        action="store_true",
+        help="Used for exllamabv2. Use 8-bit cache to save VRAM.",
+    )
     parser.add_argument(
         "--enable-xft",
         action="store_true",
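A quick illustration of how the new switch surfaces after parsing (argparse maps `--exllama-cache-8bit` to `args.exllama_cache_8bit`); the import path matches this file, and the snippet assumes none of the other arguments added by `add_model_args` are required:

```python
import argparse
from fastchat.model.model_adapter import add_model_args

parser = argparse.ArgumentParser()
add_model_args(parser)  # adds --exllama-cache-8bit among the other model flags

# With the flag present the attribute is True; omitted, it defaults to False.
args = parser.parse_args(["--exllama-cache-8bit"])
print(args.exllama_cache_8bit)  # True
```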
6 changes: 5 additions & 1 deletion fastchat/modules/exllama.py
@@ -6,6 +6,7 @@
 class ExllamaConfig:
     max_seq_len: int
     gpu_split: str = None
+    cache_8bit: bool = False
 
 
 class ExllamaModel:
@@ -22,6 +23,7 @@ def load_exllama_model(model_path, exllama_config: ExllamaConfig):
             ExLlamaV2Tokenizer,
             ExLlamaV2,
             ExLlamaV2Cache,
+            ExLlamaV2Cache_8bit,
         )
     except ImportError as e:
         print(f"Error: Failed to load Exllamav2. {e}")
@@ -31,6 +33,7 @@ def load_exllama_model(model_path, exllama_config: ExllamaConfig):
     exllamav2_config.model_dir = model_path
     exllamav2_config.prepare()
     exllamav2_config.max_seq_len = exllama_config.max_seq_len
+    exllamav2_config.cache_8bit = exllama_config.cache_8bit
 
     exllama_model = ExLlamaV2(exllamav2_config)
     tokenizer = ExLlamaV2Tokenizer(exllamav2_config)
@@ -40,7 +43,8 @@ def load_exllama_model(model_path, exllama_config: ExllamaConfig):
         split = [float(alloc) for alloc in exllama_config.gpu_split.split(",")]
         exllama_model.load(split)
 
-    exllama_cache = ExLlamaV2Cache(exllama_model)
+    cache_class = ExLlamaV2Cache_8bit if exllamav2_config.cache_8bit else ExLlamaV2Cache
+    exllama_cache = cache_class(exllama_model)
     model = ExllamaModel(exllama_model=exllama_model, exllama_cache=exllama_cache)
 
     return model, tokenizer
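The VRAM saving comes from storing the key/value cache in 8 bits per element instead of FP16, roughly halving cache memory. An illustrative back-of-the-envelope estimate; the model dimensions below are assumed for a 7B-class model and are not taken from this commit:

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem):
    # Keys and values (factor 2), one vector per layer, per KV head, per position.
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem

fp16 = kv_cache_bytes(32, 32, 128, 4096, 2)  # 2.0 GiB
int8 = kv_cache_bytes(32, 32, 128, 4096, 1)  # 1.0 GiB
print(f"fp16 cache: {fp16 / 2**30:.1f} GiB, 8-bit cache: {int8 / 2**30:.1f} GiB")
```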
1 change: 1 addition & 0 deletions fastchat/serve/cli.py
@@ -201,6 +201,7 @@ def main(args):
         exllama_config = ExllamaConfig(
             max_seq_len=args.exllama_max_seq_len,
             gpu_split=args.exllama_gpu_split,
+            cache_8bit=args.exllama_cache_8bit,
         )
     else:
         exllama_config = None
1 change: 1 addition & 0 deletions fastchat/serve/model_worker.py
@@ -323,6 +323,7 @@ def create_model_worker():
         exllama_config = ExllamaConfig(
             max_seq_len=args.exllama_max_seq_len,
             gpu_split=args.exllama_gpu_split,
+            cache_8bit=args.exllama_cache_8bit,
         )
     else:
         exllama_config = None
1 change: 1 addition & 0 deletions fastchat/serve/multi_model_worker.py
@@ -217,6 +217,7 @@ def create_multi_model_worker():
         exllama_config = ExllamaConfig(
             max_seq_len=args.exllama_max_seq_len,
             gpu_split=args.exllama_gpu_split,
+            cache_8bit=args.exllama_cache_8bit,
         )
     else:
         exllama_config = None
