Skip to content

Commit b839200

Browse files
authored
Merge branch 'dev' into supprot_aysnc_vllm
2 parents d1a117b + 8fcc0cd commit b839200

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

ChatTTS/core.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def _load(
273273
vq_config=asdict(self.config.dvae.vq),
274274
dim=self.config.dvae.decoder.idim,
275275
coef=coef,
276-
device=self.device,
276+
device=device,
277277
)
278278
.to(device)
279279
.eval()
@@ -290,8 +290,8 @@ def _load(
290290
self.config.embed.num_text_tokens,
291291
self.config.embed.num_vq,
292292
)
293-
embed.from_pretrained(embed_path, device=self.device)
294-
self.embed = embed.to(self.device)
293+
embed.from_pretrained(embed_path, device=device)
294+
self.embed = embed.to(device)
295295
self.logger.log(logging.INFO, "embed loaded.")
296296

297297
gpt = GPT(
@@ -319,6 +319,7 @@ def _load(
319319
decoder_config=asdict(self.config.decoder),
320320
dim=self.config.decoder.idim,
321321
coef=coef,
322+
device=device,
322323
)
323324
.to(device)
324325
.eval()

ChatTTS/model/dvae.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def __init__(
179179
hop_length=256,
180180
n_mels=100,
181181
padding: Literal["center", "same"] = "center",
182-
device: torch.device = torch.device("cuda"),
182+
device: torch.device = torch.device("cpu"),
183183
):
184184
super().__init__()
185185
self.device = device
@@ -213,7 +213,7 @@ def __init__(
213213
vq_config: Optional[dict] = None,
214214
dim=512,
215215
coef: Optional[str] = None,
216-
device: torch.device = torch.device("cuda"),
216+
device: torch.device = torch.device("cpu"),
217217
):
218218
super().__init__()
219219
if coef is None:

examples/ipynb/colab.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@
355355
"metadata": {},
356356
"outputs": [],
357357
"source": [
358-
"from tools.audio import load_audio\n",
358+
"from ChatTTS.tools.audio import load_audio\n",
359359
"\n",
360360
"spk_smp = chat.sample_audio_speaker(load_audio(\"sample.mp3\", 24000))\n",
361361
"print(spk_smp) # save it in order to load the speaker without sample audio next time\n",

0 commit comments

Comments
 (0)