# enhancers.py
# Forked from metavoiceio/metavoice-src.
import os
from abc import ABC
from typing import Literal, Optional
from df.enhance import enhance, init_df, load_audio, save_audio
from pydub import AudioSegment
def convert_to_wav(input_file: str, output_file: str):
    """Convert an audio file to WAV format.

    The input format is taken from the file's extension (case-insensitive).
    If the file has no extension, ``None`` is passed to pydub. The decoder
    then detects the format itself instead of receiving a bogus format string.

    Args:
        input_file (str): path to input audio file
        output_file (str): path to output WAV file
    """
    # os.path.splitext only looks at the final path component, so dots in
    # directory names (e.g. "./data.v2/clip") are not mistaken for an extension.
    extension = os.path.splitext(input_file)[1]
    audio_format = extension[1:].lower() or None
    # Read the audio file
    audio = AudioSegment.from_file(input_file, format=audio_format)
    # Export as WAV
    audio.export(output_file, format="wav")
def make_output_file_path(
    audio_file: str,
    tag: str,
    ext: Optional[str] = None,
    directory: str = "./enhanced",
) -> str:
    """Generate the output file path.

    The result is ``<directory>/<input stem><tag><extension>``. Only the
    path is computed; nothing is created on disk.

    Args:
        audio_file (str): path to input audio file
        tag (str): tag to append to the output file name
        ext (str, optional): extension of the output file, with or without
            the leading dot (".wav" or "wav"). Defaults to None; None or ""
            reuses the input file's extension.
        directory (str, optional): directory the output file is placed in.
            Defaults to "./enhanced".

    Returns:
        str: path to output file
    """
    stem, input_extension = os.path.splitext(os.path.basename(audio_file))
    if ext:
        # Accept "wav" as well as ".wav". Otherwise a bare extension would be
        # glued straight onto the stem ("clip_dfwav").
        extension = ext if ext.startswith(".") else "." + ext
    else:
        extension = input_extension
    return os.path.join(directory, stem + tag + extension)
class BaseEnhancer(ABC):
    """Common interface shared by the audio enhancers in this module.

    The base ``__init__`` and ``__call__`` both raise NotImplementedError.
    Subclasses are expected to supply both. ``get_output_file`` is a shared
    helper for choosing a default destination path.
    """

    def __init__(self, *args, **kwargs):
        # Not meant to be instantiated directly.
        raise NotImplementedError()

    def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
        """Enhance ``audio_file``; subclasses return the path that was written."""
        raise NotImplementedError()

    def get_output_file(self, audio_file: str, tag: str, ext: Optional[str] = None) -> str:
        """Build the default output path for ``audio_file`` and create its folder.

        Args:
            audio_file (str): path to input audio file
            tag (str): suffix appended to the input file's stem
            ext (str, optional): output extension; None keeps the input's.

        Returns:
            str: path to the output file (its parent directory now exists)
        """
        target_path = make_output_file_path(audio_file, tag, ext=ext)
        parent_dir = os.path.dirname(target_path)
        os.makedirs(parent_dir, exist_ok=True)
        return target_path
class DFEnhancer(BaseEnhancer):
    """Audio enhancer built on ``df.enhance`` (init_df / load_audio / enhance / save_audio)."""

    def __init__(self, *args, **kwargs):
        # init_df() is called with its defaults. Only the model and the DF
        # state are used, so any further return values are ignored.
        self.model, self.df_state = init_df()[:2]

    def __call__(self, audio_file: str, output_file: Optional[str] = None) -> str:
        """Enhance ``audio_file`` and save the result.

        Args:
            audio_file (str): path to the input audio file
            output_file (str, optional): destination path. When None or "",
                a path tagged "_df" is derived via ``get_output_file``.

        Returns:
            str: path of the written, enhanced audio file
        """
        output_file = output_file or self.get_output_file(audio_file, "_df")
        # Query the model's sample rate once; it is used for both load and save.
        sample_rate = self.df_state.sr()
        audio, _ = load_audio(audio_file, sr=sample_rate)
        enhanced = enhance(self.model, self.df_state, audio)
        # get_output_file only creates the default directory. Make sure a
        # caller-supplied destination folder exists too before writing.
        out_dir = os.path.dirname(output_file)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        save_audio(output_file, enhanced, sample_rate)
        return output_file
def get_enhancer(enhancer_name: Literal["df"]) -> BaseEnhancer:
    """Create the audio enhancer registered under ``enhancer_name``.

    Args:
        enhancer_name (Literal["df"]): enhancer identifier; "df" (the
            df.enhance-based DFEnhancer) is the only one supported.

    Raises:
        ValueError: if ``enhancer_name`` is not a known enhancer.

    Returns:
        BaseEnhancer: a freshly constructed enhancer instance.
    """
    if enhancer_name != "df":
        raise ValueError(f"Unknown enhancer name: {enhancer_name}")

    import warnings

    # Silence one specific resampling deprecation notice. filterwarnings edits
    # the process-wide filter list, so this stays in effect after returning.
    warnings.filterwarnings(
        "ignore",
        message='"sinc_interpolation" resampling method name is being deprecated and replaced by "sinc_interp_hann" in the next release. The default behavior remains unchanged.',
    )
    return DFEnhancer()