Allow deterministic generations #175

Open
wants to merge 2 commits into base: main
Changes from 1 commit
2 changes: 1 addition & 1 deletion bark/__init__.py
@@ -1,2 +1,2 @@
-from .api import generate_audio, text_to_semantic, semantic_to_waveform, save_as_prompt
+from .api import generate_audio, text_to_semantic, semantic_to_waveform, save_as_prompt, set_seed
from .generation import SAMPLE_RATE, preload_models
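
With this change, set_seed becomes importable from the package root alongside the existing helpers. A minimal sketch of the resulting import surface (illustrative only, not part of the diff):

from bark import SAMPLE_RATE, generate_audio, preload_models, set_seed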
45 changes: 45 additions & 0 deletions bark/api.py
@@ -1,6 +1,9 @@
from typing import Optional

import numpy as np
+import torch
+import random
+import os

from .generation import codec_decode, generate_coarse, generate_fine, generate_text_semantic

@@ -83,6 +86,48 @@ def save_as_prompt(filepath, full_generation):
    np.savez(filepath, **full_generation)


+def set_seed(seed: int = 0):
+    """Set the seed.
+
+    seed = 0          Generate a random seed
+    seed = -1         Disable deterministic algorithms
+    0 < seed < 2**32  Use the given seed
+
+    Args:
+        seed: integer to use as seed
+
+    Returns:
+        integer used as seed
+    """
+
+    original_seed = seed
+
+    # See https://pytorch.org/docs/stable/notes/randomness.html for more information
+    if seed == -1:
+        # Disable deterministic algorithms
+        torch.backends.cudnn.deterministic = False
+        torch.backends.cudnn.benchmark = True
+    else:
+        # Enable deterministic algorithms
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+
+    if seed <= 0:
+        # Generate a random seed
+        # Use default_rng() because it is independent of np.random.seed()
+        seed = np.random.default_rng().integers(1, 2**32 - 1)
+
+    assert 0 < seed < 2**32
+
+    np.random.seed(seed)
+    random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    os.environ["PYTHONHASHSEED"] = str(seed)
+
+    return original_seed if original_seed != 0 else seed
+
+
def generate_audio(
    text: str,
    history_prompt: Optional[str] = None,
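
Taken together, the api.py changes let a caller pin every RNG that Bark touches before generating. A minimal usage sketch under this diff's assumptions (the text, seed value, and voice preset are placeholders, not part of the change):

from bark import generate_audio, preload_models, set_seed

preload_models()

# Seed NumPy, random, and torch (CPU and CUDA), and request deterministic cuDNN kernels
used_seed = set_seed(12345)
audio_a = generate_audio("Hello, world!", history_prompt="v2/en_speaker_6")

# Re-seeding before the second call should reproduce the first waveform
set_seed(used_seed)
audio_b = generate_audio("Hello, world!", history_prompt="v2/en_speaker_6")

# set_seed(-1) would switch cuDNN back to its faster, non-deterministic mode
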
49 changes: 25 additions & 24 deletions bark/generation.py
@@ -155,6 +155,23 @@ def _grab_best_device(use_gpu=True):
        device = "cpu"
    return device

+def _load_history_prompt(history_prompt, required_keys=()):
+    x_history = None
+    if history_prompt is not None:
+        if hasattr(history_prompt, "__getitem__") and not isinstance(history_prompt, str):
+            x_history = history_prompt
+        elif history_prompt.endswith(".npz"):
+            x_history = np.load(history_prompt)
+        else:
+            assert (history_prompt in ALLOWED_PROMPTS)
+            x_history = np.load(
+                os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
+            )
+
+        for key in required_keys:
+            assert key in x_history
+
+    return x_history

S3_BUCKET_PATH_RE = r"s3\:\/\/(.+?)\/"

@@ -404,14 +421,9 @@ def generate_text_semantic(
    assert isinstance(text, str)
    text = _normalize_whitespace(text)
    assert len(text.strip()) > 0
-    if history_prompt is not None:
-        if history_prompt.endswith(".npz"):
-            semantic_history = np.load(history_prompt)["semantic_prompt"]
-        else:
-            assert (history_prompt in ALLOWED_PROMPTS)
-            semantic_history = np.load(
-                os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
-            )["semantic_prompt"]
+    x_history = _load_history_prompt(history_prompt, ["semantic_prompt"])
+    if x_history is not None:
+        semantic_history = x_history["semantic_prompt"]
        assert (
            isinstance(semantic_history, np.ndarray)
            and len(semantic_history.shape) == 1
@@ -573,14 +585,8 @@ def generate_coarse(
    assert max_coarse_history + sliding_window_len <= 1024 - 256
    semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
    max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
-    if history_prompt is not None:
-        if history_prompt.endswith(".npz"):
-            x_history = np.load(history_prompt)
-        else:
-            assert (history_prompt in ALLOWED_PROMPTS)
-            x_history = np.load(
-                os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
-            )
+    x_history = _load_history_prompt(history_prompt, ["semantic_prompt", "coarse_prompt"])
+    if x_history is not None:
        x_semantic_history = x_history["semantic_prompt"]
        x_coarse_history = x_history["coarse_prompt"]
        assert (
@@ -738,14 +744,9 @@ def generate_fine(
        and x_coarse_gen.min() >= 0
        and x_coarse_gen.max() <= CODEBOOK_SIZE - 1
    )
-    if history_prompt is not None:
-        if history_prompt.endswith(".npz"):
-            x_fine_history = np.load(history_prompt)["fine_prompt"]
-        else:
-            assert (history_prompt in ALLOWED_PROMPTS)
-            x_fine_history = np.load(
-                os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
-            )["fine_prompt"]
+    x_history = _load_history_prompt(history_prompt, ["fine_prompt"])
+    if x_history is not None:
+        x_fine_history = x_history["fine_prompt"]
        assert (
            isinstance(x_fine_history, np.ndarray)
            and len(x_fine_history.shape) == 2
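
The refactor also lets callers pass an already-loaded prompt (anything dict-like with the required arrays), in addition to the named presets and .npz paths that were supported before. A rough sketch of the three call shapes this enables (values are illustrative, not part of the diff):

import numpy as np
from bark.generation import generate_text_semantic

# 1. Built-in preset name (must be in ALLOWED_PROMPTS)
semantic = generate_text_semantic("Hello there.", history_prompt="en_speaker_1")

# 2. Path to a saved prompt file, e.g. one written by save_as_prompt(...)
semantic = generate_text_semantic("Hello there.", history_prompt="my_voice.npz")

# 3. Dict-like object holding the required arrays directly
prompt = dict(np.load("my_voice.npz"))
semantic = generate_text_semantic("Hello there.", history_prompt=prompt)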