Skip to content

Commit

Permalink
removes speaker encoder ckpt (metavoiceio#51)
Browse files Browse the repository at this point in the history
* feat: remove speaker encoder cpkt

* fix: missing attr

---------

Co-authored-by: sid <sid@themetavoice.xyz>
  • Loading branch information
sidroopdaska and sid authored Feb 13, 2024
1 parent d10f365 commit 1c100fc
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 13 deletions.
12 changes: 9 additions & 3 deletions fam/llm/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,8 +561,10 @@ def sample_utterance(
)[0]


def build_models(config_first_stage, config_second_stage, device, use_kv_cache):
smodel = SpeakerEncoder(device=device, eval=True, verbose=False)
def build_models(config_first_stage, config_second_stage, model_dir, device, use_kv_cache):
smodel = SpeakerEncoder(
weights_fpath=os.path.join(model_dir, "speaker_encoder.pt"), device=device, eval=True, verbose=False
)
data_adapter = FlattenedInterleavedEncodec2Codebook(end_of_audio_token=1024)
llm_first_stage = Model(
config_first_stage,
Expand Down Expand Up @@ -690,7 +692,11 @@ class SamplingControllerConfig:

# define models
smodel, llm_first_stage, llm_second_stage = build_models(
config_first_stage, config_second_stage, sampling_config.device, sampling_config.use_kv_cache
config_first_stage,
config_second_stage,
model_dir=model_dir,
device=sampling_config.device,
use_kv_cache=sampling_config.use_kv_cache
)

print(f"Synthesising utterance...")
Expand Down
7 changes: 6 additions & 1 deletion fam/llm/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ class ServingConfig:
enhancer: Optional[Literal["df"]] = "df"
"""Enhancer to use for post-processing."""

compile: bool = False
"""Whether to compile the model using PyTorch 2.0."""

port: int = 58003


Expand Down Expand Up @@ -181,7 +184,9 @@ def _convert_audiodata_to_wav_path(audiodata, wav_tmp):
**common_config,
)

spkemb, llm_stg1, llm_stg2 = build_models(config1, config2, device=device, use_kv_cache="flash_decoding")
spkemb, llm_stg1, llm_stg2 = build_models(
config1, config2, model_dir=model_dir, device=device, use_kv_cache="flash_decoding"
)
GlobalState.spkemb_model = spkemb
GlobalState.first_stage_model = llm_stg1
GlobalState.second_stage_model = llm_stg2
Expand Down
Binary file removed fam/quantiser/audio/speaker_encoder/ckpt/ckpt.pt
Binary file not shown.
12 changes: 3 additions & 9 deletions fam/quantiser/audio/speaker_encoder/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@

from fam.quantiser.audio.speaker_encoder import audio

DEFAULT_SPKENC_CKPT_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ckpt/ckpt.pt")

mel_window_step = 10
mel_n_channels = 40
sampling_rate = 16000
Expand Down Expand Up @@ -43,14 +41,10 @@ def __init__(
self.device = device

start = timer()
if eval and weights_fpath is None:
weights_fpath = DEFAULT_SPKENC_CKPT_PATH

if weights_fpath is not None:
checkpoint = torch.load(weights_fpath, map_location="cpu")

self.load_state_dict(checkpoint["model_state"], strict=False)
self.to(device)
checkpoint = torch.load(weights_fpath, map_location="cpu")
self.load_state_dict(checkpoint["model_state"], strict=False)
self.to(device)

if eval:
self.eval()
Expand Down

0 comments on commit 1c100fc

Please sign in to comment.