Skip to content

Commit

Permalink
Last formatting.
Browse files Browse the repository at this point in the history
  • Loading branch information
d-dawg78 committed Feb 20, 2023
1 parent 19b523f commit 5adc8cf
Show file tree
Hide file tree
Showing 12 changed files with 111 additions and 115 deletions.
2 changes: 1 addition & 1 deletion spleeter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def separate(
EVALUATION_AUDIO_DIRECTORY: str = "audio"


def _compile_metrics(metrics_output_directory) -> Dict:
def _compile_metrics(metrics_output_directory: str) -> Dict:
"""
Compiles metrics from given directory and returns results as dict.
Expand Down
24 changes: 12 additions & 12 deletions spleeter/audio/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class AudioAdapter(ABC):
"""An abstract class for manipulating audio signal."""

_DEFAULT: Optional["AudioAdapter"] = None
""" Default audio adapter singleton instance. """
"""Default audio adapter singleton instance."""

@abstractmethod
def load(
Expand All @@ -51,13 +51,13 @@ def load(
Describe song to load, in case of file based audio adapter,
such descriptor would be a file path.
offset (Optional[float]):
Start offset to load from in seconds.
(Optional) Start offset to load from in seconds.
duration (Optional[float]):
Duration to load in seconds.
(Optional) Duration to load in seconds.
sample_rate (Optional[float]):
Sample rate to load audio with.
(Optional) Sample rate to load audio with.
dtype (bytes):
(Optional)data type to use, default to `b'float32'`.
(Optional) Data type to use, default to `b'float32'`.
Returns:
Signal:
Expand All @@ -78,17 +78,17 @@ def load_waveform(
Load the audio and convert it to a tensorflow waveform.
Parameters:
audio_descriptor ():
audio_descriptor (Any):
Describe song to load, in case of file based audio adapter,
such descriptor would be a file path.
offset (float):
Start offset to load from in seconds.
(Optional) Start offset to load from in seconds.
duration (float):
Duration to load in seconds.
(Optional) Duration to load in seconds.
sample_rate (float):
Sample rate to load audio with.
(Optional) Sample rate to load audio with.
dtype (bytes):
(Optional)data type to use, default to `b'float32'`.
(Optional) Data type to use, default to `b'float32'`.
waveform_name (str):
(Optional) Name of the key in output dict, default to
`'waveform'`.
Expand Down Expand Up @@ -146,11 +146,11 @@ def save(
Parameters:
path (Union[Path, str]):
Path like of the audio file to save data in.
data (numpy.ndarray):
data (np.ndarray):
Waveform data to write.
sample_rate (float):
Sample rate to write file in.
codec ():
codec (Codec):
(Optional) Writing codec to use, default to `None`.
bitrate (str):
(Optional) Bitrate of the written audio file, default to
Expand Down
16 changes: 8 additions & 8 deletions spleeter/audio/ffmpeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# coding: utf8

"""
This module provides an AudioAdapter implementation based on FFMPEG
process. Such implementation is POSIXish and depends on nothing except
standard Python libraries. Thus this implementation is the default one
used within this library.
This module provides an AudioAdapter implementation based on FFMPEG
process. Such implementation is POSIXish and depends on nothing except
standard Python libraries. Thus this implementation is the default one
used within this library.
"""

import datetime as dt
Expand Down Expand Up @@ -77,13 +77,13 @@ def load(
path (Union[Path, str]):
Path of the audio file to load data from.
offset (Optional[float]):
Start offset to load from in seconds.
(Optional) Start offset to load from in seconds.
duration (Optional[float]):
Duration to load in seconds.
(Optional) Duration to load in seconds.
sample_rate (Optional[float]):
Sample rate to load audio with.
(Optional) Sample rate to load audio with.
dtype (bytes):
(Optional)data type to use, default to `b'float32'`.
(Optional) Data type to use, default to `b'float32'`.
Returns:
Signal:
Expand Down
12 changes: 6 additions & 6 deletions spleeter/audio/spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ def compute_spectrogram_tf(
waveform (tf.Tensor):
Input waveform as `(times x number of channels)` tensor.
frame_length (int):
Length of a STFT frame to use.
(Optional) Length of a STFT frame to use.
frame_step (int):
HOP between successive frames.
(Optional) HOP between successive frames.
spec_exponent (float):
Exponent of the spectrogram (usually 1 for magnitude
spectrogram, or 2 for power spectrogram).
(Optional) Exponent of the spectrogram (usually 1 for
magnitude spectrogram, or 2 for power spectrogram).
window_exponent (float):
Exponent applied to the Hann windowing function (may be
useful for making perfect STFT/iSTFT reconstruction).
(Optional) Exponent applied to the Hann windowing function
(may be useful for making perfect STFT/iSTFT reconstruction).
Returns:
tf.Tensor:
Expand Down
75 changes: 32 additions & 43 deletions spleeter/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@
# coding: utf8

"""
Module for building data preprocessing pipeline using the tensorflow
data API. Data preprocessing such as audio loading, spectrogram
computation, cropping, feature caching or data augmentation is done
using a tensorflow dataset object that output a tuple (input_, output)
where:
- input is a dictionary with a single key that contains the (batched)
mix spectrogram of audio samples
- output is a dictionary of spectrogram of the isolated tracks
(ground truth)
Module for building data preprocessing pipeline using the tensorflow
data API. Data preprocessing such as audio loading, spectrogram
computation, cropping, feature caching or data augmentation is done
using a tensorflow dataset object that outputs a tuple (input_, output)
where:
- input is a dictionary with a single key that contains the (batched)
mix spectrogram of audio samples
- output is a dictionary of spectrogram of the isolated tracks
(ground truth)
"""

import os
import time
from os.path import exists
from os.path import sep as SEPARATOR
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

# pyright: reportMissingImports=false
# pylint: disable=import-error
Expand Down Expand Up @@ -131,14 +131,14 @@ def get_validation_dataset(
class InstrumentDatasetBuilder(object):
"""Instrument based filter and mapper provider."""

def __init__(self, parent, instrument) -> None:
def __init__(self, parent: Any, instrument: Any) -> None:
"""
Default constructor.
Parameters:
parent:
parent (Any):
Parent dataset builder.
instrument:
instrument (Any):
Target instrument.
"""
self._parent = parent
Expand All @@ -147,7 +147,7 @@ def __init__(self, parent, instrument) -> None:
self._min_spectrogram_key = f"min_{instrument}_spectrogram"
self._max_spectrogram_key = f"max_{instrument}_spectrogram"

def load_waveform(self, sample):
def load_waveform(self, sample: Dict) -> Dict:
"""Load waveform for given sample."""
return dict(
sample,
Expand All @@ -160,7 +160,7 @@ def load_waveform(self, sample):
),
)

def compute_spectrogram(self, sample):
def compute_spectrogram(self, sample: Dict) -> Dict:
"""Compute spectrogram of the given sample."""
return dict(
sample,
Expand All @@ -175,7 +175,7 @@ def compute_spectrogram(self, sample):
},
)

def filter_frequencies(self, sample):
def filter_frequencies(self, sample: Dict) -> Dict:
return dict(
sample,
**{
Expand All @@ -185,7 +185,7 @@ def filter_frequencies(self, sample):
},
)

def convert_to_uint(self, sample):
def convert_to_uint(self, sample: Dict) -> Dict:
"""Convert given sample from float to uint."""
return dict(
sample,
Expand All @@ -197,11 +197,11 @@ def convert_to_uint(self, sample):
),
)

def filter_infinity(self, sample):
def filter_infinity(self, sample: Dict) -> tf.Tensor:
"""Filter infinity sample."""
return tf.logical_not(tf.math.is_inf(sample[self._min_spectrogram_key]))

def convert_to_float32(self, sample):
def convert_to_float32(self, sample: Dict) -> Dict:
"""Convert given sample from uint to float."""
return dict(
sample,
Expand All @@ -214,7 +214,7 @@ def convert_to_float32(self, sample):
},
)

def time_crop(self, sample):
def time_crop(self, sample: Dict) -> Dict:
def start(sample):
"""mid_segment_start"""
return tf.cast(
Expand All @@ -235,14 +235,14 @@ def start(sample):
},
)

def filter_shape(self, sample):
def filter_shape(self, sample: Dict) -> bool:
"""Filter badly shaped sample."""
return check_tensor_shape(
sample[self._spectrogram_key],
(self._parent._T, self._parent._F, self._parent._n_channels),
)

def reshape_spectrogram(self, sample):
def reshape_spectrogram(self, sample: Dict) -> Dict:
"""Reshape given sample."""
return dict(
sample,
Expand Down Expand Up @@ -272,17 +272,6 @@ def __init__(
) -> None:
"""
Default constructor.
NOTE: Probably need for AudioAdapter.
Parameters:
audio_params (Dict):
Audio parameters to use.
audio_adapter (AudioAdapter):
Audio adapter to use.
audio_path (str):
random_seed (int):
chunk_duration (float):
"""
# Length of segment in frames (if fs=22050 and
# frame_step=512, then T=512 corresponds to 11.89s)
Expand Down Expand Up @@ -321,7 +310,7 @@ def check_parameters_compatibility(self):
"(for instance reducing T or frame_step or increasing chunk duration)."
)

def expand_path(self, sample):
def expand_path(self, sample: Dict) -> Dict:
"""Expands audio paths for the given sample."""
return dict(
sample,
Expand All @@ -333,15 +322,15 @@ def expand_path(self, sample):
},
)

def filter_error(self, sample):
def filter_error(self, sample: Dict) -> tf.Tensor:
"""Filter errored sample."""
return tf.logical_not(sample["waveform_error"])

def filter_waveform(self, sample):
def filter_waveform(self, sample: Dict) -> Dict:
"""Filter waveform from sample."""
return {k: v for k, v in sample.items() if not k == "waveform"}

def harmonize_spectrogram(self, sample):
def harmonize_spectrogram(self, sample: Dict) -> Dict:
"""Ensure same size for vocals and mix spectrograms."""

def _reduce(sample):
Expand All @@ -362,7 +351,7 @@ def _reduce(sample):
},
)

def filter_short_segments(self, sample):
def filter_short_segments(self, sample: Dict) -> tf.Tensor:
"""Filter out too short segment."""
return tf.reduce_any(
[
Expand All @@ -371,7 +360,7 @@ def filter_short_segments(self, sample):
]
)

def random_time_crop(self, sample):
def random_time_crop(self, sample: Dict) -> Dict:
"""Random time crop of 11.88s."""
return dict(
sample,
Expand All @@ -388,7 +377,7 @@ def random_time_crop(self, sample):
),
)

def random_time_stretch(self, sample):
def random_time_stretch(self, sample: Dict) -> Dict:
"""Randomly time stretch the given sample."""
return dict(
sample,
Expand All @@ -401,7 +390,7 @@ def random_time_stretch(self, sample):
),
)

def random_pitch_shift(self, sample):
def random_pitch_shift(self, sample: Dict) -> Dict:
"""Randomly pitch shift the given sample."""
return dict(
sample,
Expand All @@ -415,7 +404,7 @@ def random_pitch_shift(self, sample):
),
)

def map_features(self, sample):
def map_features(self, sample: Dict) -> Tuple[Dict, Dict]:
"""Select features and annotation of the given sample."""
input_ = {
f"{self._mix_name}_spectrogram": sample[f"{self._mix_name}_spectrogram"]
Expand Down
Loading

0 comments on commit 5adc8cf

Please sign in to comment.