Skip to content

Commit

Permalink
Adapt finetuning (#314)
Browse files Browse the repository at this point in the history
* Add Windows Setup Help

* Optimize documents/bootscripts for Windows User

* Correct some description

* Fix dependencies

* fish 1.2 webui & api

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix spelling

* Fix CUDA env

* Update api usage

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Adapt finetuning

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
AnyaCoder and pre-commit-ci[bot] committed Jul 3, 2024
1 parent 74d7850 commit 3d6d1d7
Showing 1 changed file with 19 additions and 36 deletions.
55 changes: 19 additions & 36 deletions fish_speech/webui/manage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import datetime
import html
import json
import os
Expand Down Expand Up @@ -180,8 +181,6 @@ def change_infer(
infer_decoder_config,
"--llama-checkpoint-path",
infer_llama_model,
"--tokenizer",
"checkpoints/fish-speech-1.2",
]
+ (["--compile"] if infer_compile == "Yes" else []),
env=env,
Expand Down Expand Up @@ -400,6 +399,12 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
)


def generate_folder_name():
    """Return a timestamp string (YYYYMMDD_HHMMSS) suitable for naming output folders."""
    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")


def train_process(
data_path: str,
option: str,
Expand All @@ -419,12 +424,6 @@ def train_process(
llama_use_speaker,
llama_use_lora,
):
import datetime

def generate_folder_name():
now = datetime.datetime.now()
folder_name = now.strftime("%Y%m%d_%H%M%S")
return folder_name

backend = "nccl" if sys.platform == "linux" else "gloo"

Expand Down Expand Up @@ -464,14 +463,9 @@ def generate_folder_name():
"16",
]
)
ckpt_path = (
"text2semantic-sft-medium-v1.1-4k.pth"
if llama_base_config == "dual_ar_2_codebook_medium"
else "text2semantic-sft-large-v1.1-4k.pth"
)
ckpt_path = "checkpoints/fish-speech-1.2/model.pth"
lora_prefix = "lora_" if llama_use_lora else ""
llama_size = "large_" if ("large" in llama_base_config) else "medium_"
llama_name = lora_prefix + "text2semantic_" + llama_size + new_project
llama_name = lora_prefix + "text2semantic_" + new_project
latest = next(
iter(
sorted(
Expand Down Expand Up @@ -500,10 +494,7 @@ def generate_folder_name():
"--config-name",
"text2semantic_finetune",
f"project={project}",
f"ckpt_path=checkpoints/{ckpt_path}",
f"trainer.strategy.process_group_backend={backend}",
f"model@model.model={llama_base_config}",
"tokenizer.pretrained_model_name_or_path=checkpoints",
f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
f"model.optimizer.lr={llama_lr}",
Expand All @@ -514,8 +505,8 @@ def generate_folder_name():
f"trainer.precision={llama_precision}",
f"trainer.val_check_interval={llama_check_interval}",
f"trainer.accumulate_grad_batches={llama_grad_batches}",
f"train_dataset.use_speaker={llama_use_speaker}",
] + ([f"+lora@model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
f"train_dataset.interactive_prob={llama_use_speaker}",
] + ([f"+lora@model.model.lora_config=r_8_alpha_16"] if llama_use_lora else [])
logger.info(train_cmd)
subprocess.run(train_cmd)

Expand Down Expand Up @@ -573,10 +564,7 @@ def list_decoder_models():


def list_llama_models():
choices = [
str(p).replace("\\", "/") for p in Path("checkpoints").glob("text2sem*.*")
]
choices += [str(p) for p in Path("results").glob("text2sem*/**/*.ckpt")]
choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*.pth")]
if not choices:
logger.warning("No LLaMA model found")
return choices
Expand Down Expand Up @@ -627,16 +615,12 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou
merge_cmd = [
PYTHON,
"tools/llama/merge_lora.py",
"--llama-config",
lora_llama_config,
"--lora-config",
"r_8_alpha_16",
"--llama-weight",
llama_weight,
"--lora-weight",
lora_weight,
"--output",
llama_lora_output,
llama_lora_output + "_" + generate_folder_name(),
]
logger.info(merge_cmd)
subprocess.run(merge_cmd)
Expand Down Expand Up @@ -759,6 +743,7 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou
"Use LoRA can save GPU memory, but may reduce the quality of the model"
),
value=True,
interactive=False,
)
llama_ckpt = gr.Dropdown(
label=i18n("Select LLAMA ckpt"),
Expand Down Expand Up @@ -792,7 +777,6 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou
llama_base_config = gr.Dropdown(
label=i18n("Model Size"),
choices=[
"text2semantic_agent",
"text2semantic_finetune",
],
value="text2semantic_finetune",
Expand Down Expand Up @@ -865,7 +849,7 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou
maximum=1.0,
step=0.05,
value=init_llama_yml["train_dataset"][
"use_speaker"
"interactive_prob"
],
)

Expand All @@ -879,7 +863,7 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou
choices=[
"checkpoints/fish-speech-1.2/model.pth",
],
value=init_llama_yml["ckpt_path"],
value="checkpoints/fish-speech-1.2/model.pth",
allow_custom_value=True,
interactive=True,
)
Expand All @@ -902,10 +886,9 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou
"Type the path or select from the dropdown"
),
choices=[
"text2semantic_agent",
"text2semantic_finetune",
],
value="text2semantic_agent",
value="text2semantic_finetune",
allow_custom_value=True,
)
with gr.Row(equal_height=False):
Expand All @@ -914,8 +897,8 @@ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_ou
info=i18n(
"Type the path or select from the dropdown"
),
value="checkpoints/merged.ckpt",
choices=["checkpoints/merged.ckpt"],
value="checkpoints/merged",
choices=["checkpoints/merged"],
allow_custom_value=True,
interactive=True,
)
Expand Down

0 comments on commit 3d6d1d7

Please sign in to comment.