Quantization Support (#316)
* Add Windows Setup Help

* Optimize documents/boot scripts for Windows users

* Correct some descriptions

* Fix dependencies

* fish 1.2 webui & api

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix spelling

* Fix CUDA env

* Update api usage

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Adapt finetuning

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Quantization Support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
AnyaCoder and pre-commit-ci[bot] authored Jul 3, 2024
1 parent 16da313 commit ea53678
Showing 4 changed files with 115 additions and 44 deletions.
21 changes: 20 additions & 1 deletion fish_speech/models/text2semantic/llama.py
@@ -7,6 +7,7 @@
 import torch
 import torch.nn as nn
 from einops import rearrange
+from loguru import logger
 from torch import Tensor
 from torch.nn import functional as F
 from torch.nn.attention import SDPBackend, sdpa_kernel
@@ -320,7 +321,7 @@ def from_pretrained(
         lora_config: LoraConfig | None = None,
         rope_base: int | None = None,
     ) -> "BaseTransformer":
-        config = BaseModelArgs.from_pretrained(path)
+        config = BaseModelArgs.from_pretrained(str(path))
         if max_length is not None:
             config.max_seq_len = max_length
             log.info(f"Override max_seq_len to {max_length}")
@@ -348,6 +349,24 @@ def from_pretrained(
         if load_weights is False:
             log.info("Randomly initialized model")
         else:
+
+            if "int8" in str(Path(path)):
+                logger.info("Using int8 weight-only quantization!")
+                from tools.llama.quantize import WeightOnlyInt8QuantHandler
+
+                simple_quantizer = WeightOnlyInt8QuantHandler(model)
+                model = simple_quantizer.convert_for_runtime()
+
+            if "int4" in str(Path(path)):
+                logger.info("Using int4 quantization!")
+                path_comps = path.name.split("-")
+                assert path_comps[-2].startswith("g")
+                groupsize = int(path_comps[-2][1:])
+                from tools.llama.quantize import WeightOnlyInt4QuantHandler
+
+                simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
+                model = simple_quantizer.convert_for_runtime()
+
             weights = torch.load(
                 Path(path) / "model.pth", map_location="cpu", mmap=True
             )
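
Note that the int4 branch above recovers the group size from the checkpoint directory name itself. A minimal sketch of that parse, assuming a directory named with the fs-1.2-int4-g<groupsize>-<timestamp> convention used by tools/llama/quantize.py below:

    from pathlib import Path

    path = Path("checkpoints/fs-1.2-int4-g128-20240703_152030")  # assumed example
    path_comps = path.name.split("-")   # ['fs', '1.2', 'int4', 'g128', '20240703_152030']
    assert path_comps[-2].startswith("g")
    groupsize = int(path_comps[-2][1:])  # 128
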
81 changes: 71 additions & 10 deletions fish_speech/webui/manage.py
@@ -555,16 +555,16 @@ def fresh_tb_dir():


 def list_decoder_models():
-    paths = [str(p) for p in Path("checkpoints").glob("vq*.*")] + [
-        str(p) for p in Path("results").glob("vqgan*/**/*.ckpt")
-    ]
+    paths = [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")]
     if not paths:
         logger.warning("No decoder model found")
     return paths


 def list_llama_models():
-    choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*.pth")]
+    choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")]
+    choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")]
+    choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")]
     if not choices:
         logger.warning("No LLaMA model found")
     return choices
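
The added fs* pattern is what makes freshly quantized checkpoints appear in the model dropdown. A sketch of what it matches, with an assumed example path:

    from pathlib import Path

    # checkpoints/fs-1.2-int8-20240703_152030/model.pth -> its parent directory is offered
    quantized = [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")]
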
@@ -593,11 +593,7 @@ def fresh_llama_ckpt(llama_use_lora):


 def fresh_llama_model():
-    choices = [
-        str(p).replace("\\", "/") for p in Path("checkpoints").glob("text2sem*.*")
-    ]
-    choices += [str(p) for p in Path("results").glob("text2sem*/**/*.ckpt")]
-    return gr.Dropdown(choices=choices)
+    return gr.Dropdown(choices=list_llama_models())


 def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
@@ -627,6 +623,39 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
     return build_html_ok_message(i18n("Merge successfully"))


+def llama_quantify(llama_weight, quantify_mode):
+    if llama_weight is None or not Path(llama_weight).exists():
+        return build_html_error_message(
+            i18n(
+                "Path error, please check the model file exists in the corresponding path"
+            )
+        )
+    now = generate_folder_name()
+    quantify_cmd = [
+        PYTHON,
+        "tools/llama/quantize.py",
+        "--checkpoint-path",
+        llama_weight,
+        "--mode",
+        quantify_mode,
+        "--timestamp",
+        now,
+    ]
+    logger.info(quantify_cmd)
+    subprocess.run(quantify_cmd)
+    if quantify_mode == "int8":
+        quantize_path = str(
+            Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-{now}"
+        )
+    else:
+        quantize_path = str(
+            Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-g128-{now}"
+        )
+    return build_html_ok_message(
+        i18n("Quantify successfully") + f"Path: {quantize_path}"
+    )
+
+
 init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
 init_llama_yml = load_yaml_data_in_fact(llama_yml_path)

@@ -907,7 +936,34 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
                             value=i18n("Merge"), variant="primary"
                         )

-                with gr.Tab(label="Tensorboard", id=5):
+                with gr.Tab(label=i18n("Model Quantization"), id=5):
+                    with gr.Row(equal_height=False):
+                        llama_weight_to_quantify = gr.Dropdown(
+                            label=i18n("Base LLAMA Model"),
+                            info=i18n(
+                                "Type the path or select from the dropdown"
+                            ),
+                            choices=list_llama_models(),
+                            value="checkpoints/fish-speech-1.2",
+                            allow_custom_value=True,
+                            interactive=True,
+                        )
+                        quantify_mode = gr.Dropdown(
+                            label=i18n("Post-quantification Precision"),
+                            info=i18n(
+                                "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase"
+                            ),
+                            choices=["int8", "int4"],
+                            value="int8",
+                            allow_custom_value=False,
+                            interactive=True,
+                        )
+                    with gr.Row(equal_height=False):
+                        llama_quantify_btn = gr.Button(
+                            value=i18n("Quantify"), variant="primary"
+                        )
+
+                with gr.Tab(label="Tensorboard", id=6):
                     with gr.Row(equal_height=False):
                         tb_host = gr.Textbox(
                             label=i18n("Tensorboard Host"), value="127.0.0.1"
@@ -1122,6 +1178,11 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
         inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output],
         outputs=[train_error],
     )
+    llama_quantify_btn.click(
+        fn=llama_quantify,
+        inputs=[llama_weight_to_quantify, quantify_mode],
+        outputs=[train_error],
+    )
     infer_checkbox.change(
         fn=change_infer,
         inputs=[
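
The new Quantify button simply shells out to tools/llama/quantize.py. A hand-run equivalent of the command llama_quantify assembles, as a sketch (the interpreter, source checkpoint, and timestamp are assumed examples; manage.py substitutes its own configured PYTHON):

    import subprocess

    subprocess.run([
        "python",
        "tools/llama/quantize.py",
        "--checkpoint-path", "checkpoints/fish-speech-1.2",
        "--mode", "int4",
        "--timestamp", "20240703_152030",
    ])
    # manage.py then reports checkpoints/fs-1.2-int4-g128-20240703_152030 as the
    # output; the g128 suffix is hard-coded in the WebUI and matches the script's
    # default --groupsize of 128.
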
17 changes: 0 additions & 17 deletions tools/llama/generate.py
@@ -343,23 +343,6 @@ def load_model(checkpoint_path, device, precision, compile=False):
         checkpoint_path, load_weights=True
     )

-    if "int8" in str(checkpoint_path):
-        logger.info("Using int8 weight-only quantization!")
-        from .quantize import WeightOnlyInt8QuantHandler
-
-        simple_quantizer = WeightOnlyInt8QuantHandler(model)
-        model = simple_quantizer.convert_for_runtime()
-
-    if "int4" in str(checkpoint_path):
-        logger.info("Using int4 quantization!")
-        path_comps = checkpoint_path.name.split(".")
-        assert path_comps[-2].startswith("g")
-        groupsize = int(path_comps[-2][1:])
-        from .quantize import WeightOnlyInt4QuantHandler
-
-        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
-        model = simple_quantizer.convert_for_runtime()
-
     model = model.to(device=device, dtype=precision)
     logger.info(f"Restored model from checkpoint")

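
With this per-call dispatch deleted, quantized checkpoints are now recognized centrally in BaseTransformer.from_pretrained (see the llama.py hunk above). A minimal loading sketch, assuming a checkpoint directory produced by the new quantization flow:

    from fish_speech.models.text2semantic.llama import BaseTransformer

    # "int8" in the directory name triggers the weight-only int8 conversion
    # inside from_pretrained before model.pth is loaded.
    model = BaseTransformer.from_pretrained(
        "checkpoints/fs-1.2-int8-20240703_152030", load_weights=True
    )
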
40 changes: 24 additions & 16 deletions tools/llama/quantize.py
@@ -1,5 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+import datetime
+import shutil

 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
@@ -11,7 +13,8 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from .generate import load_model
+from fish_speech.models.text2semantic.llama import find_multiple
+from tools.llama.generate import load_model

 ##### Quantization Primitives ######
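
The quantization primitives section itself is untouched in the hunks shown here. For orientation, a generic sketch of symmetric per-channel, weight-only int8 quantization (an illustration of the idea, not this file's exact implementation):

    import torch

    def int8_weight_only(w: torch.Tensor):
        # w: [out_features, in_features]; one scale per output channel
        scales = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        q = torch.round(w / scales).clamp(-128, 127).to(torch.int8)
        return q, scales.squeeze(1)

    w = torch.randn(16, 64)
    q, s = int8_weight_only(w)
    w_hat = q.float() * s.unsqueeze(1)   # dequantized at runtime
    print((w - w_hat).abs().max().item())  # small reconstruction error
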

@@ -415,23 +418,26 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         )


+def generate_folder_name():
+    now = datetime.datetime.now()
+    folder_name = now.strftime("%Y%m%d_%H%M%S")
+    return folder_name
+
+
 @click.command()
 @click.option(
     "--checkpoint-path",
     type=click.Path(path_type=Path, exists=True),
     default="checkpoints/fish-speech-1.2",
 )
-@click.option("--config-name", type=str, default="dual_ar_2_codebook_medium")
 @click.option(
     "--mode", type=str, default="int8", help="type of quantization to perform"
 )
 @click.option(
     "--groupsize", type=int, default=128, help="Group size for int4 quantization."
 )
-def quantize(
-    checkpoint_path: Path, config_name: str, mode: str, groupsize: int
-) -> None:
-    assert checkpoint_path.is_file(), checkpoint_path
+@click.option("--timestamp", type=str, default="None", help="When to do quantization")
+def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:

     device = "cpu"
     precision = torch.bfloat16
@@ -440,13 +446,13 @@ def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
     t0 = time.time()

     model, _ = load_model(
-        config_name,
         checkpoint_path=checkpoint_path,
         device=device,
         precision=precision,
         compile=False,
-        max_length=2048,
     )
+    vq_model = "firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
+    now = timestamp if timestamp != "None" else generate_folder_name()

     if mode == "int8":
         print(
@@ -455,10 +461,11 @@ def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
         quant_handler = WeightOnlyInt8QuantHandler(model)
         quantized_state_dict = quant_handler.create_quantized_state_dict()

-        dir_name = checkpoint_path.parent
-        base_name = checkpoint_path.stem
-        suffix = checkpoint_path.suffix
-        quantize_path = dir_name / f"{base_name}.int8{suffix}"
+        dir_name = checkpoint_path
+        dst_name = Path(f"checkpoints/fs-1.2-int8-{now}")
+        shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
+        (dst_name / vq_model).unlink()
+        quantize_path = dst_name / "model.pth"

     elif mode == "int4":
         print(
@@ -467,10 +474,11 @@ def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -> None:
         quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
        quantized_state_dict = quant_handler.create_quantized_state_dict()

-        dir_name = checkpoint_path.parent
-        base_name = checkpoint_path.name
-        suffix = checkpoint_path.suffix
-        quantize_path = dir_name / f"{base_name}.int4.g{groupsize}{suffix}"
+        dir_name = checkpoint_path
+        dst_name = Path(f"checkpoints/fs-1.2-int4-g{groupsize}-{now}")
+        shutil.copytree(str(dir_name.resolve()), str(dst_name.resolve()))
+        (dst_name / vq_model).unlink()
+        quantize_path = dst_name / "model.pth"

     else:
         raise ValueError(
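
Both modes now materialize a complete checkpoint directory (a copy of the source directory minus the firefly VQ-GAN generator) rather than a single renamed .pth file. A quick post-run sanity check, as a sketch with assumed example paths:

    from pathlib import Path

    dst = Path("checkpoints/fs-1.2-int4-g128-20240703_152030")
    assert (dst / "model.pth").exists()  # quantized weights are written here
    # the generator is deliberately removed from the copy:
    assert not (dst / "firefly-gan-vq-fsq-4x1024-42hz-generator.pth").exists()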
