diff --git a/spleeter/__main__.py b/spleeter/__main__.py index 8bbf3b7e..c3a4429e 100644 --- a/spleeter/__main__.py +++ b/spleeter/__main__.py @@ -160,7 +160,7 @@ def separate( EVALUATION_AUDIO_DIRECTORY: str = "audio" -def _compile_metrics(metrics_output_directory) -> Dict: +def _compile_metrics(metrics_output_directory: str) -> Dict: """ Compiles metrics from given directory and returns results as dict. diff --git a/spleeter/audio/adapter.py b/spleeter/audio/adapter.py index 7ecd7386..122e00c3 100644 --- a/spleeter/audio/adapter.py +++ b/spleeter/audio/adapter.py @@ -31,7 +31,7 @@ class AudioAdapter(ABC): """An abstract class for manipulating audio signal.""" _DEFAULT: Optional["AudioAdapter"] = None - """ Default audio adapter singleton instance. """ + """Default audio adapter singleton instance.""" @abstractmethod def load( @@ -51,13 +51,13 @@ def load( Describe song to load, in case of file based audio adapter, such descriptor would be a file path. offset (Optional[float]): - Start offset to load from in seconds. + (Optional) Start offset to load from in seconds. duration (Optional[float]): - Duration to load in seconds. + (Optional) Duration to load in seconds. sample_rate (Optional[float]): - Sample rate to load audio with. + (Optional) Sample rate to load audio with. dtype (bytes): - (Optional)data type to use, default to `b'float32'`. + (Optional) Data type to use, default to `b'float32'`. Returns: Signal: @@ -78,17 +78,17 @@ def load_waveform( Load the audio and convert it to a tensorflow waveform. Parameters: - audio_descriptor (): + audio_descriptor (Any): Describe song to load, in case of file based audio adapter, such descriptor would be a file path. offset (float): - Start offset to load from in seconds. + (Optional) Start offset to load from in seconds. duration (float): - Duration to load in seconds. + (Optional) Duration to load in seconds. sample_rate (float): - Sample rate to load audio with. + (Optional) Sample rate to load audio with. dtype (bytes): - (Optional)data type to use, default to `b'float32'`. + (Optional) Data type to use, default to `b'float32'`. waveform_name (str): (Optional) Name of the key in output dict, default to `'waveform'`. @@ -146,11 +146,11 @@ def save( Parameters: path (Union[Path, str]): Path like of the audio file to save data in. - data (numpy.ndarray): + data (np.ndarray): Waveform data to write. sample_rate (float): Sample rate to write file in. - codec (): + codec (Codec): (Optional) Writing codec to use, default to `None`. bitrate (str): (Optional) Bitrate of the written audio file, default to diff --git a/spleeter/audio/ffmpeg.py b/spleeter/audio/ffmpeg.py index 76a6434b..9932593e 100644 --- a/spleeter/audio/ffmpeg.py +++ b/spleeter/audio/ffmpeg.py @@ -2,10 +2,10 @@ # coding: utf8 """ - This module provides an AudioAdapter implementation based on FFMPEG - process. Such implementation is POSIXish and depends on nothing except - standard Python libraries. Thus this implementation is the default one - used within this library. +This module provides an AudioAdapter implementation based on FFMPEG +process. Such implementation is POSIXish and depends on nothing except +standard Python libraries. Thus this implementation is the default one +used within this library. """ import datetime as dt @@ -77,13 +77,13 @@ def load( path (Union[Path, str]: Path of the audio file to load data from. offset (Optional[float]): - Start offset to load from in seconds. + (Optional) Start offset to load from in seconds. duration (Optional[float]): - Duration to load in seconds. + (Optional) Duration to load in seconds. sample_rate (Optional[float]): - Sample rate to load audio with. + (Optional) Sample rate to load audio with. dtype (bytes): - (Optional)data type to use, default to `b'float32'`. + (Optional) Data type to use, default to `b'float32'`. Returns: Signal: diff --git a/spleeter/audio/spectrogram.py b/spleeter/audio/spectrogram.py index 16c8231a..60ad229b 100644 --- a/spleeter/audio/spectrogram.py +++ b/spleeter/audio/spectrogram.py @@ -30,15 +30,15 @@ def compute_spectrogram_tf( waveform (tf.Tensor): Input waveform as `(times x number of channels)` tensor. frame_length (int): - Length of a STFT frame to use. + (Optional) Length of a STFT frame to use. frame_step (int): - HOP between successive frames. + (Optional) HOP between successive frames. spec_exponent (float): - Exponent of the spectrogram (usually 1 for magnitude - spectrogram, or 2 for power spectrogram). + (Optional) Exponent of the spectrogram (usually 1 for + magnitude spectrogram, or 2 for power spectrogram). window_exponent (float): - Exponent applied to the Hann windowing function (may be - useful for making perfect STFT/iSTFT reconstruction). + (Optional) Exponent applied to the Hann windowing function + (may be useful for making perfect STFT/iSTFT reconstruction). Returns: tf.Tensor: diff --git a/spleeter/dataset.py b/spleeter/dataset.py index 65a65875..e1c98d56 100644 --- a/spleeter/dataset.py +++ b/spleeter/dataset.py @@ -2,23 +2,23 @@ # coding: utf8 """ - Module for building data preprocessing pipeline using the tensorflow - data API. Data preprocessing such as audio loading, spectrogram - computation, cropping, feature caching or data augmentation is done - using a tensorflow dataset object that output a tuple (input_, output) - where: - - - input is a dictionary with a single key that contains the (batched) - mix spectrogram of audio samples - - output is a dictionary of spectrogram of the isolated tracks - (ground truth) +Module for building data preprocessing pipeline using the tensorflow +data API. Data preprocessing such as audio loading, spectrogram +computation, cropping, feature caching or data augmentation is done +using a tensorflow dataset object that output a tuple (input_, output) +where: + +- input is a dictionary with a single key that contains the (batched) + mix spectrogram of audio samples +- output is a dictionary of spectrogram of the isolated tracks + (ground truth) """ import os import time from os.path import exists from os.path import sep as SEPARATOR -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple # pyright: reportMissingImports=false # pylint: disable=import-error @@ -131,14 +131,14 @@ def get_validation_dataset( class InstrumentDatasetBuilder(object): """Instrument based filter and mapper provider.""" - def __init__(self, parent, instrument) -> None: + def __init__(self, parent: Any, instrument: Any) -> None: """ Default constructor. Parameters: - parent: + parent (Any): Parent dataset builder. - instrument: + instrument (Any): Target instrument. """ self._parent = parent @@ -147,7 +147,7 @@ def __init__(self, parent, instrument) -> None: self._min_spectrogram_key = f"min_{instrument}_spectrogram" self._max_spectrogram_key = f"max_{instrument}_spectrogram" - def load_waveform(self, sample): + def load_waveform(self, sample: Dict) -> Dict: """Load waveform for given sample.""" return dict( sample, @@ -160,7 +160,7 @@ def load_waveform(self, sample): ), ) - def compute_spectrogram(self, sample): + def compute_spectrogram(self, sample: Dict) -> Dict: """Compute spectrogram of the given sample.""" return dict( sample, @@ -175,7 +175,7 @@ def compute_spectrogram(self, sample): }, ) - def filter_frequencies(self, sample): + def filter_frequencies(self, sample: Dict) -> Dict: return dict( sample, **{ @@ -185,7 +185,7 @@ def filter_frequencies(self, sample): }, ) - def convert_to_uint(self, sample): + def convert_to_uint(self, sample: Dict) -> Dict: """Convert given sample from float to unit.""" return dict( sample, @@ -197,11 +197,11 @@ def convert_to_uint(self, sample): ), ) - def filter_infinity(self, sample): + def filter_infinity(self, sample: Dict) -> tf.Tensor: """Filter infinity sample.""" return tf.logical_not(tf.math.is_inf(sample[self._min_spectrogram_key])) - def convert_to_float32(self, sample): + def convert_to_float32(self, sample: Dict) -> Dict: """Convert given sample from unit to float.""" return dict( sample, @@ -214,7 +214,7 @@ def convert_to_float32(self, sample): }, ) - def time_crop(self, sample): + def time_crop(self, sample: Dict) -> Dict: def start(sample): """mid_segment_start""" return tf.cast( @@ -235,14 +235,14 @@ def start(sample): }, ) - def filter_shape(self, sample): + def filter_shape(self, sample: Dict) -> bool: """Filter badly shaped sample.""" return check_tensor_shape( sample[self._spectrogram_key], (self._parent._T, self._parent._F, self._parent._n_channels), ) - def reshape_spectrogram(self, sample): + def reshape_spectrogram(self, sample: Dict) -> Dict: """Reshape given sample.""" return dict( sample, @@ -272,17 +272,6 @@ def __init__( ) -> None: """ Default constructor. - - NOTE: Probably need for AudioAdapter. - - Parameters: - audio_params (Dict): - Audio parameters to use. - audio_adapter (AudioAdapter): - Audio adapter to use. - audio_path (str): - random_seed (int): - chunk_duration (float): """ # Length of segment in frames (if fs=22050 and # frame_step=512, then T=512 corresponds to 11.89s) @@ -321,7 +310,7 @@ def check_parameters_compatibility(self): "(for instance reducing T or frame_step or increasing chunk duration)." ) - def expand_path(self, sample): + def expand_path(self, sample: Dict) -> Dict: """Expands audio paths for the given sample.""" return dict( sample, @@ -333,15 +322,15 @@ def expand_path(self, sample): }, ) - def filter_error(self, sample): + def filter_error(self, sample: Dict) -> tf.Tensor: """Filter errored sample.""" return tf.logical_not(sample["waveform_error"]) - def filter_waveform(self, sample): + def filter_waveform(self, sample: Dict) -> Dict: """Filter waveform from sample.""" return {k: v for k, v in sample.items() if not k == "waveform"} - def harmonize_spectrogram(self, sample): + def harmonize_spectrogram(self, sample: Dict) -> Dict: """Ensure same size for vocals and mix spectrograms.""" def _reduce(sample): @@ -362,7 +351,7 @@ def _reduce(sample): }, ) - def filter_short_segments(self, sample): + def filter_short_segments(self, sample: Dict) -> tf.Tensor: """Filter out too short segment.""" return tf.reduce_any( [ @@ -371,7 +360,7 @@ def filter_short_segments(self, sample): ] ) - def random_time_crop(self, sample): + def random_time_crop(self, sample: Dict) -> Dict: """Random time crop of 11.88s.""" return dict( sample, @@ -388,7 +377,7 @@ def random_time_crop(self, sample): ), ) - def random_time_stretch(self, sample): + def random_time_stretch(self, sample: Dict) -> Dict: """Randomly time stretch the given sample.""" return dict( sample, @@ -401,7 +390,7 @@ def random_time_stretch(self, sample): ), ) - def random_pitch_shift(self, sample): + def random_pitch_shift(self, sample: Dict) -> Dict: """Randomly pitch shift the given sample.""" return dict( sample, @@ -415,7 +404,7 @@ def random_pitch_shift(self, sample): ), ) - def map_features(self, sample): + def map_features(self, sample: Dict) -> Tuple[Dict, Dict]: """Select features and annotation of the given sample.""" input_ = { f"{self._mix_name}_spectrogram": sample[f"{self._mix_name}_spectrogram"] diff --git a/spleeter/model/__init__.py b/spleeter/model/__init__.py index 670448f9..eda70a16 100644 --- a/spleeter/model/__init__.py +++ b/spleeter/model/__init__.py @@ -4,6 +4,7 @@ """ This package provide an estimator builder as well as model functions. """ import importlib +from typing import Any, Dict, Optional, Tuple # pyright: reportMissingImports=false # pylint: disable=import-error @@ -125,7 +126,7 @@ class EstimatorSpecBuilder(object): WINDOW_COMPENSATION_FACTOR = 2.0 / 3.0 EPSILON = 1e-10 - def __init__(self, features, params): + def __init__(self, features: Dict, params: Dict) -> None: """ Default constructor. Depending on built model usage, the provided features should be different: @@ -162,10 +163,6 @@ def _build_model_outputs(self): mix magnitude spectrogram, then an output dict from it according to the selected model in internal parameters. - Returns: - Dict: - Build output dict. - Raises: ValueError: If required model_type is not supported. @@ -184,20 +181,17 @@ def _build_model_outputs(self): input_tensor, self._instruments, self._params["model"]["params"] ) - def _build_loss(self, labels): + def _build_loss(self, labels: Dict) -> Tuple[tf.Tensor, Dict]: """ Construct tensorflow loss and metrics Parameters: - output_dict (Dict): - Dictionary of network outputs (key: instrument name, - value: estimated spectrogram of the instrument) labels (Dict): Dictionary of target outputs (key: instrument name, value: ground truth spectrogram of the instrument) Returns: - Tuple[tf.Tensor, Dict[str, tf.Tensor]]: + Tuple[tf.Tensor, Dict]: Tensorflow (loss, metrics) tuple. """ output_dict = self.model_outputs @@ -223,7 +217,7 @@ def _build_loss(self, labels): metrics["absolute_difference"] = tf.compat.v1.metrics.mean(loss) return loss, metrics - def _build_optimizer(self): + def _build_optimizer(self) -> tf.Tensor: """ Builds an optimizer instance from internal parameter values. Default to AdamOptimizer if not specified. @@ -323,13 +317,18 @@ def masked_stfts(self): self._build_masked_stfts() return self._masked_stfts - def _inverse_stft(self, stft_t, time_crop=None): + def _inverse_stft( + self, stft_t: tf.Tensor, time_crop: Optional[Any] = None + ) -> tf.Tensor: """ Inverse and reshape the given STFT Parameters: stft_t (tf.Tensor): Input STFT. + time_crop (Optional[Any]): + Time cropping. + Returns: tf.Tensor: Inverse STFT (waveform). @@ -350,7 +349,7 @@ def _inverse_stft(self, stft_t, time_crop=None): time_crop = tf.shape(self._features["waveform"])[0] return reshaped[self._frame_length : self._frame_length + time_crop, :] - def _build_mwf_output_waveform(self): + def _build_mwf_output_waveform(self) -> Dict: """ Perform separation with multichannel Wiener Filtering using Norbert. @@ -390,7 +389,7 @@ def _build_mwf_output_waveform(self): for k, instrument in enumerate(self._instruments) } - def _extend_mask(self, mask): + def _extend_mask(self, mask: tf.Tensor) -> tf.Tensor: """ Extend mask, from reduced number of frequency bin to the number of frequency bin in the STFT. @@ -464,7 +463,7 @@ def _build_masked_stfts(self): out[instrument] = tf.cast(mask, dtype=tf.complex64) * input_stft self._masked_stfts = out - def _build_manual_output_waveform(self, masked_stft): + def _build_manual_output_waveform(self, masked_stft: Dict) -> Dict: """ Perform ratio mask separation @@ -483,7 +482,7 @@ def _build_manual_output_waveform(self, masked_stft): output_waveform[instrument] = self._inverse_stft(stft_data) return output_waveform - def _build_output_waveform(self, masked_stft): + def _build_output_waveform(self, masked_stft: Dict) -> Dict: """ Build output waveform from given output dict in order to be used in prediction context. The configuration @@ -510,7 +509,7 @@ def _build_outputs(self): if "audio_id" in self._features: self._outputs["audio_id"] = self._features["audio_id"] - def build_predict_model(self): + def build_predict_model(self) -> tf.Tensor: """ Builder interface for creating model instance that aims to perform prediction / inference over given track. The output of such estimator @@ -526,7 +525,7 @@ def build_predict_model(self): tf.estimator.ModeKeys.PREDICT, predictions=self.outputs ) - def build_evaluation_model(self, labels): + def build_evaluation_model(self, labels: Dict) -> tf.Tensor: """ Builder interface for creating model instance that aims to perform model evaluation. The output of such estimator @@ -535,7 +534,7 @@ def build_evaluation_model(self, labels): separated instrument magnitude spectrogram. Parameters: - labels (): + labels (Dict): Model labels. Returns: @@ -547,7 +546,7 @@ def build_evaluation_model(self, labels): tf.estimator.ModeKeys.EVAL, loss=loss, eval_metric_ops=metrics ) - def build_train_model(self, labels): + def build_train_model(self, labels: Dict) -> tf.Tensor: """ Builder interface for creating model instance that aims to perform model training. The output of such estimator will be a dictionary @@ -555,7 +554,7 @@ def build_train_model(self, labels): associated to the estimated separated instrument magnitude spectrogram. Parameters: - labels (): + labels (Dict): Model labels. Returns: diff --git a/spleeter/model/functions/__init__.py b/spleeter/model/functions/__init__.py index b8496e87..77c5eec9 100644 --- a/spleeter/model/functions/__init__.py +++ b/spleeter/model/functions/__init__.py @@ -32,7 +32,7 @@ def apply( Tensor to apply blstm to. instruments (Iterable[str]): Iterable that provides a collection of instruments. - params (Dict): + params (Optional[Dict]): (Optional) dict of BLSTM parameters. Returns: diff --git a/spleeter/model/functions/blstm.py b/spleeter/model/functions/blstm.py index d9baf6a3..550b19da 100644 --- a/spleeter/model/functions/blstm.py +++ b/spleeter/model/functions/blstm.py @@ -55,7 +55,7 @@ def apply_blstm( Input of the model. output_name (str): (Optional) name of the output, default to 'output'. - params (Dict): + params (Optional[Dict]): (Optional) dict of BLSTM parameters. Returns: diff --git a/spleeter/model/functions/unet.py b/spleeter/model/functions/unet.py index b08e7f1d..6c304bb5 100644 --- a/spleeter/model/functions/unet.py +++ b/spleeter/model/functions/unet.py @@ -45,10 +45,9 @@ def _get_conv_activation_layer(params: Dict) -> Any: """ - > To be documented. - Parameters: params (Dict): + Model parameters. Returns: Any: @@ -64,10 +63,9 @@ def _get_conv_activation_layer(params: Dict) -> Any: def _get_deconv_activation_layer(params: Dict) -> Any: """ - > To be documented. - Parameters: params (Dict): + Model parameters. Returns: Any: @@ -86,16 +84,24 @@ def apply_unet( output_name: str = "output", params: Dict = {}, output_mask_logit: bool = False, -) -> Any: +) -> tf.Tensor: """ Apply a convolutionnal U-net to model a single instrument (one U-net is used for each instrument). Parameters: input_tensor (tf.Tensor): + Input of the model. output_name (str): - params (Optional[Dict]): + (Optional) name of the output, default to 'output'. + params (Dict): + (Optional) dict of BLSTM parameters. output_mask_logit (bool): + (Optional) Sigmoid or logit? + + Returns: + tf.Tensor: + Output tensor. """ logging.info(f"Apply unet for {output_name}") conv_n_filters = params.get("conv_n_filters", [16, 32, 64, 128, 256, 512]) diff --git a/spleeter/separator.py b/spleeter/separator.py index c11f6445..adef3987 100644 --- a/spleeter/separator.py +++ b/spleeter/separator.py @@ -2,16 +2,16 @@ # coding: utf8 """ - Module that provides a class wrapper for source separation. +Module that provides a class wrapper for source separation. - Examples: +Examples: - ```python - >>> from spleeter.separator import Separator - >>> separator = Separator('spleeter:2stems') - >>> separator.separate(waveform, lambda instrument, data: ...) - >>> separator.separate_to_file(...) - ``` +```python +>>> from spleeter.separator import Separator +>>> separator = Separator('spleeter:2stems') +>>> separator.separate(waveform, lambda instrument, data: ...) +>>> separator.separate_to_file(...) +``` """ import atexit @@ -64,13 +64,15 @@ def __call__(self) -> Generator: buffer = self._current_data -def create_estimator(params, MWF): +def create_estimator(params: Dict, MWF: bool) -> tf.Tensor: """ Initialize tensorflow estimator that will perform separation Parameters: params (Dict): A dictionary of parameters for building the model + MWF (bool): + Wiener filter enabled? Returns: tf.Tensor: @@ -107,9 +109,9 @@ def __init__( params_descriptor (str): Descriptor for TF params to be used. MWF (bool): - `True` if MWF should be used, `False` otherwise. + (Optional) `True` if MWF should be used, `False` otherwise. multiprocess (bool): - Enable multi-processing. + (Optional) Enable multi-processing. """ self._params = load_configuration(params_descriptor) self._sample_rate = self._params["sample_rate"] @@ -158,7 +160,7 @@ def join(self, timeout: int = 200) -> None: Parameters: timeout (int): - Task waiting timeout. + (Optional) Task waiting timeout. """ while len(self._tasks) > 0: task = self._tasks.pop() @@ -227,9 +229,9 @@ def separate( Performs separation on a waveform. Parameters: - waveform (numpy.ndarray): + waveform (np.ndarray): Waveform to be separated (as a numpy array) - audio_descriptor (str): + audio_descriptor (Optional[str]): (Optional) string describing the waveform (e.g. filename). Returns: diff --git a/spleeter/utils/logging.py b/spleeter/utils/logging.py index b38d7e02..bc028807 100644 --- a/spleeter/utils/logging.py +++ b/spleeter/utils/logging.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # coding: utf8 -""" Centralized logging facilities for Spleeter. """ +"""Centralized logging facilities for Spleeter.""" import logging import warnings diff --git a/spleeter/utils/tensor.py b/spleeter/utils/tensor.py index 22f6d973..ba2d2874 100644 --- a/spleeter/utils/tensor.py +++ b/spleeter/utils/tensor.py @@ -40,7 +40,7 @@ def sync_apply( Function to be applied to the concatenation of the tensors in `tensor_dict`. concat_axis (int): - The axis on which to perform the concatenation. + (Optional) The axis on which to perform the concatenation. Returns: Dict[str, tf.Tensor]: