Skip to content

Commit

Permalink
Merge pull request #740 from AznamirWoW/ZludaRefactoring
Browse files (browse the repository at this point in the history)
Re-worked Zluda business in a much simpler way
  • Loading branch information
blaisewf authored Sep 27, 2024
2 parents 39a38b2 + 537dfc7 commit 268a81f
Show file tree
Hide file tree
Showing 11 changed files with 65 additions and 97 deletions.
3 changes: 3 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
now_dir = os.getcwd()
sys.path.append(now_dir)

# Zluda hijack
import rvc.lib.zluda

# Import Tabs
from tabs.inference.inference import inference_tab
from tabs.train.train import train_tab
Expand Down
7 changes: 0 additions & 7 deletions rvc/configs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,6 @@ def device_config(self) -> tuple:
def set_cuda_config(self):
i_device = int(self.device.split(":")[-1])
self.gpu_name = torch.cuda.get_device_name(i_device)
# Zluda
if self.gpu_name.endswith("[ZLUDA]"):
print("Zluda compatibility enabled, experimental feature.")
torch.backends.cudnn.enabled = False
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(False)
low_end_gpus = ["16", "P40", "P10", "1060", "1070", "1080"]
if (
any(gpu in self.gpu_name for gpu in low_end_gpus)
Expand Down
19 changes: 0 additions & 19 deletions rvc/lib/algorithm/commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,25 +156,6 @@ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
acts = t_act * s_act
return acts


# Zluda, same as previous, but without jit.script
def fused_add_tanh_sigmoid_multiply_no_jit(input_a, input_b, n_channels):
    """
    Gated activation computed eagerly (no ``torch.jit.script`` decoration).

    The two inputs are summed, the result is split along dimension 1 at
    ``n_channels[0]``, ``tanh`` is applied to the first part and ``sigmoid``
    to the remainder, and the two halves are multiplied element-wise.

    Args:
        input_a: The first input tensor.
        input_b: The second input tensor, added to ``input_a``.
        n_channels: Indexable whose first element is the split point along
            dimension 1.

    Returns:
        The element-wise product of the tanh and sigmoid parts.
    """
    split_at = n_channels[0]
    summed = input_a + input_b
    tanh_part = torch.tanh(summed[:, :split_at, :])
    sigmoid_part = torch.sigmoid(summed[:, split_at:, :])
    return tanh_part * sigmoid_part


def convert_pad_shape(pad_shape: List[List[int]]) -> List[int]:
"""
Convert the pad shape to a list of integers.
Expand Down
20 changes: 2 additions & 18 deletions rvc/lib/algorithm/modules.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
import torch
from rvc.lib.algorithm.commons import (
fused_add_tanh_sigmoid_multiply_no_jit,
fused_add_tanh_sigmoid_multiply,
)

from rvc.lib.algorithm.commons import fused_add_tanh_sigmoid_multiply

class WaveNet(torch.nn.Module):
"""WaveNet residual blocks as used in WaveGlow
Expand Down Expand Up @@ -88,11 +84,6 @@ def forward(self, x, x_mask, g=None, **kwargs):
if g is not None:
g = self.cond_layer(g)

# Zluda
is_zluda = x.device.type == "cuda" and torch.cuda.get_device_name().endswith(
"[ZLUDA]"
)

for i in range(self.n_layers):
x_in = self.in_layers[i](x)
if g is not None:
Expand All @@ -101,14 +92,7 @@ def forward(self, x, x_mask, g=None, **kwargs):
else:
g_l = torch.zeros_like(x_in)

# Preventing HIP crash by not using jit-decorated function
if is_zluda:
acts = fused_add_tanh_sigmoid_multiply_no_jit(
x_in, g_l, n_channels_tensor
)
else:
acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)

acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
acts = self.drop(acts)

res_skip_acts = self.res_skip_layers[i](acts)
Expand Down
25 changes: 10 additions & 15 deletions rvc/lib/predictors/F0Extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,34 +40,30 @@ def wav16k(self) -> np.ndarray:
def extract_f0(self) -> np.ndarray:
f0 = None
method = self.method
# Fall back to CPU for ZLUDA as these methods use CUcFFT
device = (
"cpu"
if "cuda" in config.device
and torch.cuda.get_device_name().endswith("[ZLUDA]")
else config.device
)

if method == "crepe":
wav16k_torch = torch.FloatTensor(self.wav16k).unsqueeze(0).to(device)
wav16k_torch = torch.FloatTensor(self.wav16k).unsqueeze(0).to(config.device)
f0 = torchcrepe.predict(
wav16k_torch,
sample_rate=16000,
hop_length=160,
batch_size=512,
fmin=self.f0_min,
fmax=self.f0_max,
device=device,
device=config.device,
)
f0 = f0[0].cpu().numpy()
elif method == "fcpe":
audio = librosa.to_mono(self.x)
audio_length = len(audio)
f0_target_length = (audio_length // self.hop_length) + 1
audio = (
torch.from_numpy(audio).float().unsqueeze(0).unsqueeze(-1).to(device)
torch.from_numpy(audio)
.float()
.unsqueeze(0)
.unsqueeze(-1)
.to(config.device)
)
model = torchfcpe.spawn_bundled_infer_model(device=device)
model = torchfcpe.spawn_bundled_infer_model(device=config.device)

f0 = model.infer(
audio,
Expand All @@ -81,11 +77,10 @@ def extract_f0(self) -> np.ndarray:
)
f0 = f0.squeeze().cpu().numpy()
elif method == "rmvpe":
is_half = False if device == "cpu" else config.is_half
model_rmvpe = RMVPE0Predictor(
os.path.join("rvc", "models", "predictors", "rmvpe.pt"),
is_half=is_half,
device=device,
is_half=config.is_half,
device=config.device,
# hop_length=80
)
f0 = model_rmvpe.infer_from_audio(self.wav16k, thred=0.03)
Expand Down
8 changes: 1 addition & 7 deletions rvc/lib/predictors/FCPE.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,12 +139,6 @@ def get_mel(self, y, keyshift=0, speed=1, center=False, train=False):
y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode)
y = y.squeeze(1)

# Zluda, fall-back to CPU for FFTs since HIP SDK has no cuFFT alternative
source_device = y.device
if y.device.type == "cuda" and torch.cuda.get_device_name().endswith("[ZLUDA]"):
y = y.to("cpu")
hann_window[keyshift_key] = hann_window[keyshift_key].to("cpu")

spec = torch.stft(
y,
n_fft_new,
Expand All @@ -156,7 +150,7 @@ def get_mel(self, y, keyshift=0, speed=1, center=False, train=False):
normalized=False,
onesided=True,
return_complex=True,
).to(source_device)
)
spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9))

# Handle keyshift and mel conversion
Expand Down
11 changes: 1 addition & 10 deletions rvc/lib/predictors/RMVPE.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,15 +408,6 @@ def forward(self, audio, keyshift=0, speed=1, center=True):
self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
audio.device
)

# Zluda, fall-back to CPU for FFTs since HIP SDK has no cuFFT alternative
source_device = audio.device
if audio.device.type == "cuda" and torch.cuda.get_device_name().endswith(
"[ZLUDA]"
):
audio = audio.to("cpu")
self.hann_window[keyshift_key] = self.hann_window[keyshift_key].to("cpu")

fft = torch.stft(
audio,
n_fft=n_fft_new,
Expand All @@ -425,7 +416,7 @@ def forward(self, audio, keyshift=0, speed=1, center=True):
window=self.hann_window[keyshift_key],
center=center,
return_complex=True,
).to(source_device)
)

magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
if keyshift != 0:
Expand Down
41 changes: 41 additions & 0 deletions rvc/lib/zluda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch
# Keep a handle on the real implementation before anything is patched, so the
# CPU fall-back below never recurses into itself.
_torch_stft = torch.stft


def z_stft(
    audio: torch.Tensor,
    n_fft: int,
    hop_length: int = None,
    win_length: int = None,
    window: torch.Tensor = None,
    center: bool = True,
    pad_mode: str = "reflect",
    normalized: bool = False,
    onesided: bool = None,
    return_complex: bool = None,
):
    """
    Run ``torch.stft`` on the CPU and move the result back to the input device.

    Used as a replacement for ``torch.stft`` under ZLUDA, where FFTs must fall
    back to the CPU (the HIP SDK has no cuFFT alternative).

    Args:
        audio: The input signal; its device is where the result is returned.
        n_fft: Size of the Fourier transform.
        hop_length: Forwarded unchanged to ``torch.stft``.
        win_length: Forwarded unchanged to ``torch.stft``.
        window: Optional window tensor. It is copied to the CPU for the call;
            ``None`` is forwarded as-is.
        center: Forwarded unchanged to ``torch.stft``.
        pad_mode: Forwarded unchanged to ``torch.stft``.
        normalized: Forwarded unchanged to ``torch.stft``.
        onesided: Forwarded unchanged to ``torch.stft``.
        return_complex: Forwarded unchanged to ``torch.stft``.

    Returns:
        The STFT result, placed on the device ``audio`` came from.
    """
    source_device = audio.device
    # ``window`` defaults to None; calling ``.to`` on it unconditionally would
    # raise AttributeError for callers that rely on the default.
    cpu_window = window.to("cpu") if window is not None else None
    spec = _torch_stft(
        audio.to("cpu"),
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        window=cpu_window,
        center=center,
        pad_mode=pad_mode,
        normalized=normalized,
        onesided=onesided,
        return_complex=return_complex,
    )
    return spec.to(source_device)


def z_jit(f, *_, **__):
    """
    No-op replacement for ``torch.jit.script``.

    Returns ``f`` itself without compiling it (jit-decorated functions crash
    under HIP). An empty ``torch._C.Graph()`` is attached as ``f.graph``,
    presumably so code that inspects ``.graph`` on scripted objects keeps
    working. Any extra positional or keyword arguments are ignored.
    """
    f.graph = torch._C.Graph()
    return f


if torch.cuda.is_available() and torch.cuda.get_device_name().endswith("[ZLUDA]"):
    # hijacks
    torch.stft = z_stft
    torch.jit.script = z_jit
    # disabling unsupported cudnn
    torch.backends.cudnn.enabled = False
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
10 changes: 3 additions & 7 deletions rvc/train/extract/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,12 @@
import shutil
from distutils.util import strtobool

# Zluda
if torch.cuda.is_available() and torch.cuda.get_device_name().endswith("[ZLUDA]"):
torch.backends.cudnn.enabled = False
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(False)

now_dir = os.getcwd()
sys.path.append(os.path.join(now_dir))

# Zluda hijack
import rvc.lib.zluda

from rvc.lib.utils import load_audio, load_embedding
from rvc.train.extract.preparing_files import generate_config, generate_filelist
from rvc.lib.predictors.RMVPE import RMVPE0Predictor
Expand Down
8 changes: 1 addition & 7 deletions rvc/train/mel_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,6 @@ def spectrogram_torch(y, n_fft, hop_size, win_size, center=False):
)
y = y.squeeze(1)

# Zluda, fall-back to CPU for FFTs since HIP SDK has no cuFFT alternative
source_device = y.device
if y.device.type == "cuda" and torch.cuda.get_device_name().endswith("[ZLUDA]"):
y = y.to("cpu")
hann_window[wnsize_dtype_device] = hann_window[wnsize_dtype_device].to("cpu")

spec = torch.stft(
y,
n_fft,
Expand All @@ -93,7 +87,7 @@ def spectrogram_torch(y, n_fft, hop_size, win_size, center=False):
normalized=False,
onesided=True,
return_complex=True,
).to(source_device)
)

spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)

Expand Down
10 changes: 3 additions & 7 deletions rvc/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
now_dir = os.getcwd()
sys.path.append(os.path.join(now_dir))

# Zluda hijack
import rvc.lib.zluda

from utils import (
HParams,
plot_spectrogram_to_numpy,
Expand Down Expand Up @@ -375,12 +378,6 @@ def run(

if torch.cuda.is_available():
torch.cuda.set_device(rank)
if torch.cuda.get_device_name().endswith("[ZLUDA]"):
print("Disabling CUDNN for traning with Zluda")
torch.backends.cudnn.enabled = False
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(False)

# Create datasets and dataloaders
train_dataset = TextAudioLoaderMultiNSFsid(config.data)
Expand Down Expand Up @@ -1027,7 +1024,6 @@ def save_to_json(
with open(file_path, "w") as f:
json.dump(data, f)


if __name__ == "__main__":
torch.multiprocessing.set_start_method("spawn")
main()

0 comments on commit 268a81f

Please sign in to comment.